diff --git a/.github/actions/locate-vcvarsall-and-setup-env/action.yml b/.github/actions/locate-vcvarsall-and-setup-env/action.yml index bf1016bf2265b..e174f384caa94 100644 --- a/.github/actions/locate-vcvarsall-and-setup-env/action.yml +++ b/.github/actions/locate-vcvarsall-and-setup-env/action.yml @@ -16,8 +16,8 @@ runs: - name: Setup VCPKG uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: - vcpkg-version: '2025.04.09' - vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' + vcpkg-version: '2025.06.13' + vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' cmake-version: '3.31.6' cmake-hash: '0f1584e8666cf4a65ec514bd02afe281caabf1d45d2c963f3151c41484f457386aa03273ab25776a670be02725354ce0b46f3a5121857416da37366342a833a0' add-cmake-to-path: 'true' diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 092b6fc8f5ce5..8df0064e06a1d 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -133,8 +133,8 @@ jobs: - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: - vcpkg-version: '2025.04.09' - vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' + vcpkg-version: '2025.06.13' + vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' cmake-version: '3.31.6' cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04' add-cmake-to-path: 'true' @@ -168,6 +168,7 @@ jobs: --build_shared_lib --cmake_generator=Ninja --build_java + --update --build --test shell: bash @@ -237,6 +238,7 @@ jobs: --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_generator=Ninja --build_java + --update --build --test shell: bash - name: Install psutil for emulator shutdown by run_android_emulator.py diff --git a/.github/workflows/linux-dnnl.yml b/.github/workflows/linux-dnnl.yml index f6e4fe5708140..da393c1af3cee 100644 --- a/.github/workflows/linux-dnnl.yml +++ b/.github/workflows/linux-dnnl.yml @@ -33,7 +33,8 @@ jobs: architecture: x64 dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu docker_image_repo: onnxruntimecpubuildpythonx64 - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --build_nuget --use_dnnl' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --build_nuget' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' + execution_providers: 'dnnl' secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 51d008a34a964..bde704edc2b6b 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -52,8 +52,8 @@ jobs: architecture: ${{ env.buildArch }} - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: - vcpkg-version: '2025.03.19' - vcpkg-hash: '17e96169cd3f266c4716fcdc1bb728e6a64f103941ece463a2834d50694eba4fb48f30135503fd466402afa139abc847ef630733c442595d1c34979f261b0114' + vcpkg-version: '2025.06.13' + vcpkg-hash: 
'735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc'
           cmake-version: '3.31.6'
           cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04'
           add-cmake-to-path: 'true'
diff --git a/.github/workflows/linux_migraphx_ci.yml b/.github/workflows/linux_migraphx_ci.yml
new file mode 100644
index 0000000000000..ee5e8bf12d651
--- /dev/null
+++ b/.github/workflows/linux_migraphx_ci.yml
@@ -0,0 +1,40 @@
+# This workflow builds and tests ONNX Runtime on Linux for the MIGraphX EP.
+# It leverages a reusable workflow (`reusable_linux_build.yml`) to handle the core build and test logic
+# within Docker containers, ensuring a consistent environment.
+# This file is very similar to linux_ci.yml, but much simpler.
+
+
+name: Linux MigraphX CI
+
+on:
+  push:
+    branches: [main, 'rel-*']
+  pull_request:
+    branches: [main, 'rel-*']
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+  packages: write
+  attestations: write
+  id-token: write
+
+jobs:
+  build-linux-x64-release-migraphx:
+    name: Build Linux x64 Release (migraphx EP)
+    uses: ./.github/workflows/reusable_linux_build.yml
+    with:
+      pool_name: "onnxruntime-github-Ubuntu2204-AMD-CPU"
+      build_config: Release
+      architecture: x64
+      dockerfile_path: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile
+      docker_image_repo: onnxruntimetrainingmigraphx-cibuild-rocm
+      extra_build_flags: '--enable_training --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --rocm_version=6.4 --rocm_home /opt/rocm --nccl_home /opt/rocm --enable_nccl --skip_submodule_sync'
+      run_tests: false
+      execution_providers: 'migraphx'
+    secrets:
+      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml
index 5f90d9430342e..7532d363b19eb 100644
--- a/.github/workflows/linux_minimal_build.yml
+++ b/.github/workflows/linux_minimal_build.yml
@@ -45,8 +45,8 @@ jobs:
       - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7
         with:
-          vcpkg-version: '2025.04.09'
-          vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce'
+          vcpkg-version: '2025.06.13'
+          vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc'
           cmake-version: '3.31.6'
           cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04'
           add-cmake-to-path: 'true'
@@ -153,8 +153,8 @@ jobs:
       - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7
         with:
-          vcpkg-version: '2025.04.09'
-          vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce'
+          vcpkg-version: '2025.06.13'
+          vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc'
           cmake-version: '3.31.6'
           cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04'
           add-cmake-to-path: 'true'
@@ -193,8 +193,8 @@ jobs:
       - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7
         with:
-          vcpkg-version: '2025.04.09'
-          vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce'
+          vcpkg-version: '2025.06.13'
+          vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc'
           cmake-version: '3.31.6'
           cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04'
           add-cmake-to-path: 'true'
@@ -231,8 +231,8 @@ jobs:
       - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7
         with:
-          vcpkg-version: '2025.04.09'
-          vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce'
+          vcpkg-version: '2025.06.13'
+          vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc'
           cmake-version: '3.31.6'
           cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04'
           add-cmake-to-path: 'true'
diff --git a/.github/workflows/linux_webgpu.yml b/.github/workflows/linux_webgpu.yml
new file mode 100644
index 0000000000000..08789489b12a3
--- /dev/null
+++ b/.github/workflows/linux_webgpu.yml
@@ -0,0 +1,101 @@
+name: Linux WebGPU CI
+
+on:
+  push:
+    branches: [main, 'rel-*']
+  pull_request:
+    branches: [main, 'rel-*']
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+  packages: write
+  attestations: write
+  id-token: write
+
+jobs:
+  build-linux-webgpu-x64-release:
+    name: Build Linux WebGPU x64 Release
+    # This job runs on a CPU node using the reusable build workflow
+    uses: ./.github/workflows/reusable_linux_build.yml
+    with:
+      pool_name: "onnxruntime-github-Ubuntu2204-AMD-CPU" # Build pool
+      build_config: Release
+      architecture: x64
+      dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu
+      docker_image_repo: onnxruntimecpubuildpythonx64
+      extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --use_webgpu --build_java --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON'
+      python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH'
+      run_tests: false
+      upload_build_output: true
+      execution_providers: 'webgpu'
+    secrets:
+      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Pass token for reusable workflow needs (e.g., docker build action)
+
+  # TODO: the following test step is currently failing. Need to fix and re-enable it.
+
+  # test-linux-webgpu-x64-release:
+  #   name: Test Linux WebGPU x64 Release
+  #   needs: build-linux-webgpu-x64-release
+  #   runs-on:
+  #     - self-hosted
+  #     - "1ES.Pool=Onnxruntime-github-Linux-GPU-A100-WUS3"
+  #   permissions:
+  #     contents: read
+  #     packages: read
+  #   steps:
+  #     - name: Checkout code
+  #       uses: actions/checkout@v4
+
+  #     - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7
+  #       id: build_docker_image_step
+  #       with:
+  #         dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu
+  #         image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64
+  #         push: true
+  #         azure-container-registry-name: onnxruntimebuildcache
+  #       env:
+  #         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Pass token to action
+
+  #     # --- Download Build Artifact to Runner Temp Directory ---
+  #     - name: Download Build Artifact
+  #       uses: actions/download-artifact@v4
+  #       with:
+  #         name: build-output-x64-Release # Must match the upload name
+  #         path: ${{ runner.temp }}/Release # Download contents into temp dir structure
+
+  #     # --- Restore Permissions in the Temp Directory ---
+  #     - name: Restore Executable Permissions
+  #       if: success() # Only run if download succeeded
+  #       working-directory: ${{ runner.temp }}/Release
+  #       shell: bash
+  #       run: |
+  #         if [ -f perms.txt ]; then
+  #           echo "Restoring executable permissions in ${{ runner.temp }}/Release ..."
+  #           while IFS= read -r file; do
+  #             # Check relative path existence within the current directory
+  #             if [ -f "$file" ]; then
+  #               chmod +x "$file"
+  #             else
+  #               echo "Warning: File '$file' listed in perms.txt not found."
+  #             fi
+  #           done < perms.txt
+  #           echo "Permissions restored."
+  #         else
+  #           echo "Warning: perms.txt not found in artifact."
+  #         fi
+
+  #     - name: Test ONNX Runtime
+  #       id: test_step
+  #       uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7
+  #       with:
+  #         docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }}
+  #         build_config: Release
+  #         mode: 'test' # Set mode to test
+  #         execution_providers: 'webgpu'
+  #         extra_build_flags: '--use_binskim_compliant_compile_flags --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON'
+  #         python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH'
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 2bb6048ae4bea..7f6f61df59f3b 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -98,12 +98,15 @@ include_patterns = [
     '**/*.cc',
     '**/*.hpp',
     '**/*.cpp',
+    '**/*.cuh',
+    '**/*.cu',
     '**/*.m',
     '**/*.mm',
 ]
 exclude_patterns = [
     'java/**', # FIXME: Enable clang-format for java
     'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**', # Contains data chunks
+    'onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/*.generated.cu', # Generated code
     'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
     'onnxruntime/test/flatbuffers/*.fbs.h', # Generated code
     'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index c6e16d4a3920f..a99bf8d1e4bee 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -93,7 +93,7 @@ option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
 option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)
 cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
-cmake_dependent_option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot
product attention" ON "onnxruntime_USE_CUDA; NOT WIN32" OFF) +option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -131,7 +131,6 @@ option(onnxruntime_ENABLE_INSTRUMENT "Enable Instrument with Event Tracing for W option(onnxruntime_USE_TELEMETRY "Build with Telemetry" OFF) cmake_dependent_option(onnxruntime_USE_MIMALLOC "Override new/delete and arena allocator with mimalloc" OFF "WIN32;NOT onnxruntime_USE_CUDA;NOT onnxruntime_USE_OPENVINO" OFF) option(onnxruntime_USE_CANN "Build with CANN support" OFF) -option(onnxruntime_USE_ROCM "Build with AMD GPU support" OFF) option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF) @@ -193,7 +192,6 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF) -option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD "Enable WebAssembly Relaxed SIMD" OFF) # Enable bitcode for iOS @@ -223,7 +221,6 @@ option(onnxruntime_PREBUILT_PYTORCH_PATH "Path to pytorch installation dir") option(onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH "Path to external transformer src dir") option(onnxruntime_ENABLE_CUDA_PROFILING "Enable CUDA kernel profiling" OFF) -option(onnxruntime_ENABLE_ROCM_PROFILING "Enable ROCM kernel profiling" OFF) option(onnxruntime_ENABLE_CPUINFO "Enable cpuinfo" ON) @@ -232,21 +229,6 @@ option(onnxruntime_ENABLE_ATEN "Enable ATen fallback" OFF) # dlpack support cmake_dependent_option(onnxruntime_ENABLE_DLPACK "Enable dlpack" ON "onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_ATEN OR onnxruntime_ENABLE_PYTHON" OFF) - -# Triton support -option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) - -# composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set -option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) -option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) -option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) -option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) - -option(onnxruntime_BUILD_CACHE "onnxruntime build with cache" OFF) -# https://zeux.io/2010/11/22/z7-everything-old-is-new-again/ -cmake_dependent_option(MSVC_Z7_OVERRIDE "replacing /Zi and /ZI with /Z7 when using MSVC with CCache" ON "onnxruntime_BUILD_CACHE; MSVC" 
OFF) - option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF) option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF) option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF) @@ -292,176 +274,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() endif() -if (onnxruntime_USE_ROCM) - if (WIN32) - message(FATAL_ERROR "ROCM does not support build in Windows!") - endif() - if (onnxruntime_USE_CUDA) - message(FATAL_ERROR "ROCM does not support build with CUDA!") - endif() - - # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake - # with modification - if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version") - message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n") - file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") - message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm_version.h ****\n") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") - message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h ****\n") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - endif() - - if (ROCM_VERSION_MATCH) - set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) - set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) - set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) - set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") - math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") - - message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") - message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") - message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") - message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") - message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") - else() - message(FATAL_ERROR "Cannot determine ROCm version string") - endif() - - - if (NOT CMAKE_HIP_COMPILER) - set(CMAKE_HIP_COMPILER "${onnxruntime_ROCM_HOME}/llvm/bin/clang++") - endif() - - if (NOT CMAKE_HIP_ARCHITECTURES) - if (ROCM_VERSION_DEV VERSION_LESS "6.2") - message(FATAL_ERROR "CMAKE_HIP_ARCHITECTURES is not set when ROCm version < 6.2") - else() - set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx940;gfx941;gfx942;gfx1200;gfx1201") - endif() - endif() - - file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*) - list(APPEND CMAKE_PREFIX_PATH ${rocm_cmake_components}) - # Force cmake to accept the configured HIP compiler. Because the configured CMAKE_PREFIX_PATH does not work during - # enable_language(HIP), we might need to move configuring of CMAKE_PREFIX_PATH to build.py (in the future). 
- set(CMAKE_HIP_COMPILER_FORCED ON) - - enable_language(HIP) - # NOTE: Flags -mllvm -amdgpu-early-inline-all=true are critical for gpu kernel code performance. -mllvm passes the - # next flag to underlying LLVM instead of clang and -amdgpu-early-inline-all=true is the optimization flag for LLVM. - # With CMake's enable_language(HIP), additional flags including the proceeding one are propagated from - # hip-lang::device library. But in some weird cases, the hip-lang::device target may not be properly configured, for - # example, the CMAKE_PREFIX_PATH might be improperly configured. - if(NOT DEFINED _CMAKE_HIP_DEVICE_RUNTIME_TARGET) - message(FATAL_ERROR "HIP Language is not properly configured.") - endif() - add_compile_options("$<$:SHELL:-x hip>") - - if (NOT onnxruntime_HIPIFY_PERL) - find_path(HIPIFY_PERL_PATH - NAMES hipify-perl - HINTS - ${onnxruntime_ROCM_HOME}/bin - ${onnxruntime_ROCM_HOME}/hip/bin) - if (HIPIFY_PERL_PATH-NOTFOUND) - MESSAGE(FATAL_ERROR "hipify-perl not found") - endif() - MESSAGE("HIPIFY PATH: ${HIPIFY_PERL_PATH}/hipify-perl") - set(onnxruntime_HIPIFY_PERL ${HIPIFY_PERL_PATH}/hipify-perl) - endif() - - # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/1a10751731784942dcbb9c0524c1369a29d45244/cmake/public/LoadHIP.cmake#L45-L109 - # with modification - set(ROCM_INCLUDE_DIRS "${onnxruntime_ROCM_HOME}/include") - set(PROJECT_RANDOM_BINARY_DIR "${CMAKE_BINARY_DIR}") - set(file "${CMAKE_BINARY_DIR}/detect_rocm_version.cc") - - # Find ROCM version for checks - # ROCM 5.0 and later will have header api for version management - if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm_version.h) - file(WRITE ${file} "" - "#include \n" - ) - elseif(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) - file(WRITE ${file} "" - "#include \n" - ) - else() - message(FATAL_ERROR "********************* rocm_version.h couldnt be found ******************\n") - endif() - - file(APPEND ${file} "" - "#include \n" - - "#ifndef ROCM_VERSION_PATCH\n" - "#define ROCM_VERSION_PATCH 0\n" - "#endif\n" - "#define STRINGIFYHELPER(x) #x\n" - "#define STRINGIFY(x) STRINGIFYHELPER(x)\n" - "int main() {\n" - " printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n" - " return 0;\n" - "}\n" - ) - - try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} - CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" - RUN_OUTPUT_VARIABLE rocm_version_from_header - COMPILE_OUTPUT_VARIABLE output_var - ) - # We expect the compile to be successful if the include directory exists. 
- if(NOT compile_result) - message(FATAL_ERROR "ROCM: Couldn't determine version from header: " ${output_var}) - endif() - message(STATUS "ROCM: Header version is: " ${rocm_version_from_header}) - set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header}) - - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - - if (ROCM_VERSION_DEV_MATCH) - set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) - set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) - set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) - set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") - math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") - else() - message(FATAL_ERROR "Cannot determine ROCm version string") - endif() - message("\n***** ROCm version from rocm_version.h ****\n") - message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") - message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") - message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") - message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") - message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") - message("\n***** HIP LANGUAGE CONFIG INFO ****\n") - message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") - message("CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}") - message("CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}") - string(TOUPPER ${CMAKE_BUILD_TYPE} BUILD_TYPE) - message("CMAKE_HIP_FLAGS_${BUILD_TYPE}: ${CMAKE_HIP_FLAGS_${BUILD_TYPE}}") - add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT}) - - if (onnxruntime_USE_COMPOSABLE_KERNEL AND ROCM_VERSION_DEV VERSION_LESS "5.3") - message(WARNING "composable kernel is only supported on ROCm >= 5.3") - set(onnxruntime_USE_COMPOSABLE_KERNEL OFF) - set(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE OFF) - endif() - if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE AND ROCM_VERSION_DEV VERSION_LESS "6.0") - message(WARNING "ck_tile can only be enabled on ROCm >= 6.0 due to compatibility and compilation speed, disable automatically") - set(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE OFF) - endif() - if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE AND CMAKE_BUILD_TYPE STREQUAL "Debug") - message(WARNING "ck_tile hits compiler error in Debug build, disable automatically") - set(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE OFF) - endif() -endif() @@ -1048,10 +860,6 @@ if (onnxruntime_USE_ARMNN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_ARMNN=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES armnn) endif() -if (onnxruntime_USE_ROCM) - list(APPEND ORT_PROVIDER_FLAGS -DUSE_ROCM=1) - list(APPEND ONNXRUNTIME_PROVIDER_NAMES rocm) -endif() if (onnxruntime_USE_COREML) list(APPEND ORT_PROVIDER_FLAGS -DUSE_COREML=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES coreml) @@ -1292,32 +1100,6 @@ function(onnxruntime_set_compile_flags target_name) target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") endif() endif() - if (onnxruntime_USE_ROCM) - # flags are detected with CXX language mode, some flags are not supported with hipclang - # because we may mix gcc and hipclang - set(ORT_HIP_WARNING_FLAGS ${ORT_WARNING_FLAGS}) - list(REMOVE_ITEM ORT_HIP_WARNING_FLAGS -Wno-nonnull-compare) - # Unsupported by Clang 18 yet. 
- list(REMOVE_ITEM ORT_HIP_WARNING_FLAGS -Wno-dangling-reference) - - list(REMOVE_ITEM ORT_HIP_WARNING_FLAGS -Wno-interference-size) - # float16.h:90:12: error: ‘tmp’ is used uninitialized - list(APPEND ORT_HIP_WARNING_FLAGS -Wno-uninitialized) - list(APPEND ORT_HIP_WARNING_FLAGS -Wno-deprecated-copy) - - # some #pragma unroll will fail, do not treat them as error - # #warning must not be treated as error - list(APPEND ORT_HIP_WARNING_FLAGS -Wno-error=pass-failed "-Wno-error=#warnings") - - # otherwise error: builtin __has_trivial_assign is deprecated; use __is_trivially_assignable instead - if (ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.4") - list(APPEND ORT_HIP_WARNING_FLAGS "-Wno-deprecated-builtins") - endif() - - foreach(FLAG ${ORT_HIP_WARNING_FLAGS}) - target_compile_options(${target_name} PRIVATE "$<$:SHELL:${FLAG}>") - endforeach() - endif() endfunction() function(onnxruntime_set_source_file_properties target_name) @@ -1346,14 +1128,6 @@ function(onnxruntime_configure_target target_name) set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL TRUE) endif() - if (onnxruntime_BUILD_KERNEL_EXPLORER) - get_target_property(target_type ${target_name} TYPE) - if (target_type STREQUAL "MODULE_LIBRARY" OR target_type STREQUAL "SHARED_LIBRARY") - set_property(TARGET ${target_name} - APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker --version-script=${ONNXRUNTIME_ROOT}/python/tools/kernel_explorer/version_script.lds ") - endif() - endif() - # Keep BinSkim happy if(MSVC AND NOT onnxruntime_target_platform MATCHES "ARM") target_link_options(${target_name} PRIVATE "/CETCOMPAT") @@ -1652,10 +1426,6 @@ if (onnxruntime_ENABLE_CUDA_PROFILING) add_compile_definitions(ENABLE_CUDA_PROFILING) endif() -if (onnxruntime_ENABLE_ROCM_PROFILING) - add_compile_definitions(ENABLE_ROCM_PROFILING) -endif() - if (onnxruntime_ENABLE_TRAINING) add_compile_definitions(ENABLE_TRAINING_CORE) add_compile_definitions(ENABLE_STRIDED_TENSORS) @@ -1678,7 +1448,7 @@ if (UNIX OR onnxruntime_USE_NCCL) if (onnxruntime_USE_NCCL) if (onnxruntime_USE_CUDA) set(NCCL_LIBNAME "nccl") - elseif (onnxruntime_USE_ROCM OR onnxruntime_USE_MIGRAPHX) + elseif (onnxruntime_USE_MIGRAPHX) set(NCCL_LIBNAME "rccl") endif() find_path(NCCL_INCLUDE_DIR @@ -1871,11 +1641,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() endif() -if(onnxruntime_BUILD_KERNEL_EXPLORER) - message(STATUS "Kernel Explorer Build is enabled") - list(APPEND ONNXRUNTIME_CMAKE_FILES onnxruntime_kernel_explorer) -endif() - # When GDK_PLATFORM is set then WINAPI_FAMILY is defined in gdk_toolchain.cmake (along with other relevant flags/definitions). if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) if (NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 6647312e99d8f..59d99ade131cd 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -50,11 +50,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") endif() - if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - string(APPEND CMAKE_C_FLAGS " -DORT_WASM64") - string(APPEND CMAKE_CXX_FLAGS " -DORT_WASM64") - endif() - # Build WebAssembly with multi-threads support. 
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth") diff --git a/cmake/deps.txt b/cmake/deps.txt index 0f2f02305a992..91bd2cef2268d 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -53,7 +53,6 @@ safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252 tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.9.2.zip;b7f8dc4a879765127ce31dfeabd31c556c80ec79 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4 -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 dawn;https://github.com/google/dawn/archive/9733be39e18186961d503e064874afe3e9ceb8d1.zip;2a4017c32892b90d072a9102eba90ae691fae36d diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 304a8c83959d8..20f97b96d316a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -738,7 +738,13 @@ if (onnxruntime_USE_WEBGPU) # Dawn disabled f16 support for NVIDIA Vulkan by default because of crashes in f16 CTS tests (crbug.com/tint/2164). # Since the crashes are limited to specific GPU models, we patched Dawn to remove the restriction. # - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch) + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch && + + # The dawn_binskim.patch contains the following changes: + # + # - (private) Fulfill the BinSkim requirements + # Some build warnings are not allowed to be disabled in project level. + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_binskim.patch) onnxruntime_fetchcontent_declare( dawn diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index c4a8641e02444..ae6684b061883 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -148,7 +148,7 @@ if(onnxruntime_BUILD_SHARED_LIB) if (APPLE) target_link_options(onnxruntime PRIVATE "LINKER:-dead_strip") elseif(NOT CMAKE_SYSTEM_NAME MATCHES "AIX") - target_link_options(onnxruntime PRIVATE "LINKER:--version-script=${SYMBOL_FILE}" "LINKER:--no-undefined" "LINKER:--gc-sections") + target_link_options(onnxruntime PRIVATE "LINKER:--version-script=${SYMBOL_FILE}" "LINKER:--no-undefined" "LINKER:--gc-sections" "LINKER:-z,noexecstack") endif() else() target_link_options(onnxruntime PRIVATE "-DEF:${SYMBOL_FILE}") diff --git a/cmake/onnxruntime_compile_triton_kernel.cmake b/cmake/onnxruntime_compile_triton_kernel.cmake deleted file mode 100644 index 9ecb8cf93265c..0000000000000 --- a/cmake/onnxruntime_compile_triton_kernel.cmake +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -find_package(Python3 COMPONENTS Interpreter REQUIRED) - -# set all triton kernel ops that need to be compiled -if(onnxruntime_USE_ROCM) - set(triton_kernel_scripts - "onnxruntime/core/providers/rocm/math/softmax_triton.py" - "onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py" - ) -endif() - -function(compile_triton_kernel out_triton_kernel_obj_file out_triton_kernel_header_dir) - # compile triton kernel, generate .a and .h files - set(triton_kernel_compiler "${REPO_ROOT}/tools/ci_build/compile_triton.py") - set(out_dir "${CMAKE_CURRENT_BINARY_DIR}/triton_kernels") - set(out_obj_file "${out_dir}/triton_kernel_infos.a") - set(header_file "${out_dir}/triton_kernel_infos.h") - - list(TRANSFORM triton_kernel_scripts PREPEND "${REPO_ROOT}/") - - add_custom_command( - OUTPUT ${out_obj_file} ${header_file} - COMMAND Python3::Interpreter ${triton_kernel_compiler} - --header ${header_file} - --script_files ${triton_kernel_scripts} - --obj_file ${out_obj_file} - DEPENDS ${triton_kernel_scripts} ${triton_kernel_compiler} - COMMENT "Triton compile generates: ${out_obj_file}" - ) - add_custom_target(onnxruntime_triton_kernel DEPENDS ${out_obj_file} ${header_file}) - set(${out_triton_kernel_obj_file} ${out_obj_file} PARENT_SCOPE) - set(${out_triton_kernel_header_dir} ${out_dir} PARENT_SCOPE) -endfunction() diff --git a/cmake/onnxruntime_csharp.cmake b/cmake/onnxruntime_csharp.cmake index 39533429e181c..cbd435cecf034 100644 --- a/cmake/onnxruntime_csharp.cmake +++ b/cmake/onnxruntime_csharp.cmake @@ -34,10 +34,6 @@ if (onnxruntime_USE_OPENVINO) STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_OPENVINO;") endif() -if (onnxruntime_USE_ROCM) - STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_ROCM;") -endif() - if (onnxruntime_USE_TENSORRT) STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_TENSORRT;") endif() diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake deleted file mode 100644 index 65a20c4229290..0000000000000 --- a/cmake/onnxruntime_kernel_explorer.cmake +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -include(CheckLanguage) - -if(NOT onnxruntime_ENABLE_PYTHON) - message(FATAL_ERROR "python is required but is not enabled") -endif() - -set(KERNEL_EXPLORER_ROOT ${ONNXRUNTIME_ROOT}/python/tools/kernel_explorer) - -if (onnxruntime_USE_CUDA) - check_language(CUDA) - set(LANGUAGE CUDA) - set(BERT_DIR ${ONNXRUNTIME_ROOT}/contrib_ops/cuda/bert) -elseif(onnxruntime_USE_ROCM) - check_language(HIP) - set(LANGUAGE HIP) - if (onnxruntime_USE_COMPOSABLE_KERNEL) - include(composable_kernel) - endif() - if (onnxruntime_USE_HIPBLASLT) - find_package(hipblaslt REQUIRED) - endif() - set(BERT_DIR ${ONNXRUNTIME_ROOT}/contrib_ops/rocm/bert) -endif() - -file(GLOB kernel_explorer_srcs CONFIGURE_DEPENDS - "${KERNEL_EXPLORER_ROOT}/*.cc" - "${KERNEL_EXPLORER_ROOT}/*.h" -) - -file(GLOB kernel_explorer_kernel_srcs CONFIGURE_DEPENDS - "${KERNEL_EXPLORER_ROOT}/kernels/*.cc" - "${KERNEL_EXPLORER_ROOT}/kernels/*.h" - "${KERNEL_EXPLORER_ROOT}/kernels/*.cu" - "${KERNEL_EXPLORER_ROOT}/kernels/*.cuh" -) - -onnxruntime_add_shared_library_module(kernel_explorer ${kernel_explorer_srcs} ${kernel_explorer_kernel_srcs}) -set_target_properties(kernel_explorer PROPERTIES PREFIX "_") -target_include_directories(kernel_explorer PUBLIC - $ - ${KERNEL_EXPLORER_ROOT}) -target_link_libraries(kernel_explorer PRIVATE $) -target_compile_definitions(kernel_explorer PRIVATE $) -target_compile_options(kernel_explorer PRIVATE -Wno-sign-compare) - -if (onnxruntime_USE_CUDA) - file(GLOB kernel_explorer_cuda_kernel_srcs CONFIGURE_DEPENDS - "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cc" - "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.h" - "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cu" - "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cuh" - ) - target_sources(kernel_explorer PRIVATE ${kernel_explorer_cuda_kernel_srcs}) - target_include_directories(kernel_explorer PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -elseif (onnxruntime_USE_ROCM) - file(GLOB kernel_explorer_rocm_kernel_srcs CONFIGURE_DEPENDS - "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cc" - "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.h" - "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cu" - "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cuh" - ) - auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs}) - target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs}) - target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1 HIPBLAS) - if (onnxruntime_USE_COMPOSABLE_KERNEL) - target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL) - if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE) - target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL_CK_TILE) - endif() - target_link_libraries(kernel_explorer PRIVATE onnxruntime_composable_kernel_includes) - endif() - if (onnxruntime_USE_TRITON_KERNEL) - target_compile_definitions(kernel_explorer PRIVATE USE_TRITON_KERNEL) - endif() - if (onnxruntime_USE_HIPBLASLT) - target_compile_definitions(kernel_explorer PRIVATE USE_HIPBLASLT) - endif() - if (onnxruntime_USE_ROCBLAS_EXTENSION_API) - target_compile_definitions(kernel_explorer PRIVATE USE_ROCBLAS_EXTENSION_API) - target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS) - target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_BETA_FEATURES_API) - endif() -endif() - -add_dependencies(kernel_explorer onnxruntime_pybind11_state) - -enable_testing() -find_package(Python COMPONENTS Interpreter REQUIRED) -# add_test(NAME test_kernels COMMAND ${Python_EXECUTABLE} -m pytest ..) 
diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index e60cfbe1c0566..5ab196fdf4980 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -136,8 +136,4 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - -if (onnxruntime_USE_ROCM) - add_dependencies(onnxruntime_optimizer generate_hipified_files) -endif() +endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index d1984156187f6..68ee177c88902 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -95,9 +95,7 @@ endif() if(onnxruntime_USE_ARMNN) set(PROVIDERS_ARMNN onnxruntime_providers_armnn) endif() -if(onnxruntime_USE_ROCM) - set(PROVIDERS_ROCM onnxruntime_providers_rocm) -endif() + if (onnxruntime_USE_XNNPACK) set(PROVIDERS_XNNPACK onnxruntime_providers_xnnpack) endif() @@ -188,10 +186,6 @@ if (onnxruntime_USE_ARMNN) include(onnxruntime_providers_armnn.cmake) endif() -if (onnxruntime_USE_ROCM) - include(onnxruntime_providers_rocm.cmake) -endif() - if (onnxruntime_USE_VSINPU) include(onnxruntime_providers_vsinpu.cmake) endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index 5a2dfb3210988..a5f165dcbc4d3 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -25,15 +25,6 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh" ) -file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cc" -) - -file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cu" - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cuh" -) file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h" diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 2e3589a1506d1..91707c485d3c5 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -241,18 +241,6 @@ ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) endif() - if (onnxruntime_USE_TRITON_KERNEL) - # compile triton kernel, generate .a and .h files - include(onnxruntime_compile_triton_kernel.cmake) - compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir) - add_dependencies(${target} onnxruntime_triton_kernel) - target_compile_definitions(${target} PRIVATE USE_TRITON_KERNEL) - target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) - target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) - # lib cuda needed by cuLaunchKernel - target_link_libraries(${target} PRIVATE CUDA::cuda_driver) - endif() - include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) target_link_libraries(${target} PRIVATE Eigen3::Eigen) diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake deleted file mode 100644 index 03f1e288f4d0d..0000000000000 --- a/cmake/onnxruntime_providers_rocm.cmake +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) Microsoft 
Corporation. All rights reserved. -# Licensed under the MIT License. - - add_definitions(-DUSE_ROCM=1) - include(onnxruntime_rocm_hipify.cmake) - - list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME}) - - find_package(HIP) - find_package(hiprand REQUIRED) - find_package(hipblas REQUIRED) - find_package(MIOpen REQUIRED) - find_package(hipfft REQUIRED) - - # MIOpen version - if(NOT DEFINED ENV{MIOPEN_PATH}) - set(MIOPEN_PATH ${onnxruntime_ROCM_HOME}) - else() - set(MIOPEN_PATH $ENV{MIOPEN_PATH}) - endif() - find_path(MIOPEN_VERSION_H_PATH - NAMES version.h - HINTS - ${MIOPEN_PATH}/include/miopen - ${MIOPEN_PATH}/miopen/include) - if (MIOPEN_VERSION_H_PATH-NOTFOUND) - MESSAGE(FATAL_ERROR "miopen version.h not found") - endif() - MESSAGE(STATUS "Found miopen version.h at ${MIOPEN_VERSION_H_PATH}") - - file(READ ${MIOPEN_VERSION_H_PATH}/version.h MIOPEN_HEADER_CONTENTS) - string(REGEX MATCH "define MIOPEN_VERSION_MAJOR * +([0-9]+)" - MIOPEN_VERSION_MAJOR "${MIOPEN_HEADER_CONTENTS}") - string(REGEX REPLACE "define MIOPEN_VERSION_MAJOR * +([0-9]+)" "\\1" - MIOPEN_VERSION_MAJOR "${MIOPEN_VERSION_MAJOR}") - string(REGEX MATCH "define MIOPEN_VERSION_MINOR * +([0-9]+)" - MIOPEN_VERSION_MINOR "${MIOPEN_HEADER_CONTENTS}") - string(REGEX REPLACE "define MIOPEN_VERSION_MINOR * +([0-9]+)" "\\1" - MIOPEN_VERSION_MINOR "${MIOPEN_VERSION_MINOR}") - string(REGEX MATCH "define MIOPEN_VERSION_PATCH * +([0-9]+)" - MIOPEN_VERSION_PATCH "${MIOPEN_HEADER_CONTENTS}") - string(REGEX REPLACE "define MIOPEN_VERSION_PATCH * +([0-9]+)" "\\1" - MIOPEN_VERSION_PATCH "${MIOPEN_VERSION_PATCH}") - set(MIOPEN_VERSION_DEV "${MIOPEN_VERSION_MAJOR}.${MIOPEN_VERSION_MINOR}.${MIOPEN_VERSION_PATCH}") - math(EXPR MIOPEN_VERSION_DEV_INT "(${MIOPEN_VERSION_MAJOR}*10000) + (${MIOPEN_VERSION_MINOR}*100) + ${MIOPEN_VERSION_PATCH}") - message("MIOPEN_VERSION_DEV: ${MIOPEN_VERSION_DEV}") - message("MIOPEN_VERSION_DEV_INT: ${MIOPEN_VERSION_DEV_INT}") - add_definitions(-DMIOPEN_VERSION=${MIOPEN_VERSION_DEV_INT}) - - find_library(RCCL_LIB rccl REQUIRED) - find_library(ROCTRACER_LIB roctracer64 REQUIRED) - find_package(rocm_smi REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::hipblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB}) - include_directories(${ROCM_SMI_INCLUDE_DIR}) - link_directories(${ROCM_SMI_LIB_DIR}) - - file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cc" - ) - - # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. 
- file(GLOB_RECURSE onnxruntime_providers_rocm_shared_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - ) - - file(GLOB_RECURSE onnxruntime_providers_rocm_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cu" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cuh" - ) - - hipify("onnxruntime/core/providers" provider_excluded_files onnxruntime_providers_rocm_generated_cc_srcs onnxruntime_providers_rocm_generated_cu_srcs) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_rocm_cc_srcs} ${onnxruntime_providers_rocm_shared_srcs} ${onnxruntime_providers_rocm_cu_srcs}) - set(onnxruntime_providers_rocm_src ${onnxruntime_providers_rocm_cc_srcs} ${onnxruntime_providers_rocm_shared_srcs} ${onnxruntime_providers_rocm_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_providers_rocm_generated_cc_srcs} ${onnxruntime_providers_rocm_generated_cu_srcs}) - - # disable contrib ops conditionally - if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - hipify("onnxruntime/contrib_ops" contrib_ops_excluded_files onnxruntime_rocm_generated_contrib_ops_cc_srcs onnxruntime_rocm_generated_contrib_ops_cu_srcs) - - # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio - source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} ${onnxruntime_rocm_generated_contrib_ops_cu_srcs}) - endif() - - if (onnxruntime_ENABLE_TRAINING_OPS) - file(GLOB_RECURSE onnxruntime_rocm_training_ops_cc_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cc" - ) - - file(GLOB_RECURSE onnxruntime_rocm_training_ops_cu_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cu" - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cuh" - ) - - hipify("orttraining/orttraining/training_ops" training_ops_excluded_files onnxruntime_rocm_generated_training_ops_cc_srcs onnxruntime_rocm_generated_training_ops_cu_srcs) - - # NCCL is not support in Windows build - if (WIN32 OR NOT onnxruntime_USE_NCCL) - list(REMOVE_ITEM onnxruntime_rocm_generated_training_ops_cc_srcs - "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/nccl_common.cc" - "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/nccl_kernels.cc" - "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/megatron.cc" - ) - endif() - - source_group(TREE ${ORTTRAINING_ROOT} FILES ${onnxruntime_rocm_training_ops_cc_srcs} ${onnxruntime_rocm_training_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_training_ops_cc_srcs} ${onnxruntime_rocm_training_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_training_ops_cc_srcs} ${onnxruntime_rocm_generated_training_ops_cu_srcs}) - endif() - - auto_set_source_files_hip_language(${onnxruntime_providers_rocm_src}) - onnxruntime_add_shared_library_module(onnxruntime_providers_rocm ${onnxruntime_providers_rocm_src}) - target_compile_options(onnxruntime_providers_rocm PRIVATE -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) - 
target_link_options(onnxruntime_providers_rocm PRIVATE -T ${REPO_ROOT}/cmake/hip_fatbin_insert) - - if(NOT MSVC) - target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-sign-compare) - target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-unused-parameter) - target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template) - endif() - - onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) - if (onnxruntime_ENABLE_TRAINING_OPS) - onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_training) - target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_training) - if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - onnxruntime_add_include_to_target(onnxruntime_providers_rocm Python::Module) - endif() - endif() - - add_custom_target(generate_hipified_files DEPENDS - ${onnxruntime_providers_rocm_generated_cc_srcs} - ${onnxruntime_providers_rocm_generated_cu_srcs} - ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} - ${onnxruntime_rocm_generated_contrib_ops_cu_srcs} - ${onnxruntime_rocm_generated_training_ops_cc_srcs} - ${onnxruntime_rocm_generated_training_ops_cu_srcs}) - - add_dependencies(onnxruntime_providers_rocm generate_hipified_files onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROCM_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} ${ABSEIL_LIBS} Eigen3::Eigen) - target_include_directories(onnxruntime_providers_rocm SYSTEM - PRIVATE - ${ONNXRUNTIME_ROOT} - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime - PUBLIC - ${onnxruntime_ROCM_HOME}/include - ${onnxruntime_ROCM_HOME}/include/roctracer) - - set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX) - set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_rocm PRIVATE HIPBLAS) - - if (onnxruntime_ENABLE_TRAINING) - target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS}) - - # RCCL is enabled by default for ROCM builds - #if (onnxruntime_USE_NCCL) - # target_include_directories(onnxruntime_providers_rocm PRIVATE ${NCCL_INCLUDE_DIRS}) - # target_link_libraries(onnxruntime_providers_rocm PRIVATE ${NCCL_LIBRARIES}) - #endif() - endif() - - if (onnxruntime_USE_ROCBLAS_EXTENSION_API) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_ROCBLAS_EXTENSION_API) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API) - endif() - - if (onnxruntime_USE_HIPBLASLT) - find_package(hipblaslt REQUIRED) - target_link_libraries(onnxruntime_providers_rocm PRIVATE roc::hipblaslt) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_HIPBLASLT) - endif() - - if (onnxruntime_USE_TRITON_KERNEL) - # compile triton kernel, generate .a and .h files - include(onnxruntime_compile_triton_kernel.cmake) - compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir) - add_dependencies(onnxruntime_providers_rocm onnxruntime_triton_kernel) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_TRITON_KERNEL) - target_include_directories(onnxruntime_providers_rocm PRIVATE 
${triton_kernel_header_dir}) - target_link_libraries(onnxruntime_providers_rocm PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) - endif() - - if (onnxruntime_USE_COMPOSABLE_KERNEL) - include(composable_kernel) - target_link_libraries(onnxruntime_providers_rocm PRIVATE - onnxruntime_composable_kernel_includes - # Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which - # are extremely slow to compile. Instead, we only link all gemm related objects. See the following directory on - # updating. - # https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/library/src/tensor_operation_instance/gpu - device_gemm_instance - device_gemm_add_fastgelu_instance - device_gemm_fastgelu_instance - device_gemm_splitk_instance - device_gemm_streamk_instance - device_batched_gemm_instance - device_softmax_instance - ) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL) - if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE) - target_link_libraries(onnxruntime_providers_rocm PUBLIC onnxruntime_composable_kernel_fmha) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL_CK_TILE) - endif() - endif() - - if(UNIX) - set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") - else() - message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it") - endif() - - if (onnxruntime_ENABLE_ATEN) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE ENABLE_ATEN) - endif() - file(GLOB ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS CONFIGURE_DEPENDS - "${REPO_ROOT}/include/onnxruntime/core/providers/rocm/*.h" - ) - set_target_properties(onnxruntime_providers_rocm PROPERTIES - PUBLIC_HEADER "${ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS}") - install(TARGETS onnxruntime_providers_rocm - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers/rocm - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake deleted file mode 100644 index 111033c780712..0000000000000 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -find_package(Python3 COMPONENTS Interpreter REQUIRED) - -# GLOB pattern of file to be excluded -set(contrib_ops_excluded_files - "bert/cudnn_fmha/*" - "bert/cutlass_fmha/*" - "bert/fastertransformer_decoder_attention/*" - "bert/flash_attention/*" - "bert/tensorrt_fused_multihead_attention/*" - "bert/attention.cc" - "bert/attention.h" - "bert/attention_impl.cu" - "bert/attention_softmax.h" - "bert/attention_softmax.cu" - "bert/attention_prepare_qkv.cu" - "bert/attention_kernel_options.h" - "bert/attention_kernel_options.cc" - "bert/decoder_attention_impl.h" - "bert/decoder_attention_impl.cu" - "bert/decoder_masked_multihead_attention.h" - "bert/decoder_masked_multihead_attention.cc" - "bert/decoder_masked_self_attention.h" - "bert/decoder_masked_self_attention.cc" - "bert/multihead_attention.cc" - "bert/multihead_attention.h" - "bert/relative_attn_bias.cc" - "bert/relative_attn_bias.h" - "bert/relative_attn_bias_impl.cu" - "bert/relative_attn_bias_impl.h" - "bert/skip_layer_norm.cc" - "bert/skip_layer_norm.h" - "bert/skip_layer_norm_impl.cu" - "bert/skip_layer_norm_impl.h" - "bert/transformer_common.h" - "bert/transformer_common.cc" - "bert/packed_attention.h" - "bert/packed_attention.cc" - "bert/packed_attention_impl.h" - "bert/packed_attention_impl.cu" - "bert/packed_multihead_attention.h" - "bert/packed_multihead_attention.cc" - "bert/packed_multihead_attention_impl.h" - "bert/packed_multihead_attention_impl.cu" - "diffusion/group_norm_impl.cu" - "diffusion/nhwc_conv.cc" - "math/gemm_float8.cc" - "math/gemm_float8.cu" - "math/gemm_float8.h" - "moe/*" - "sparse/*" - "quantization/attention_quantization.cc" - "quantization/attention_quantization.h" - "quantization/attention_quantization_impl.cu" - "quantization/attention_quantization_impl.cuh" - "quantization/dequantize_blockwise_bnb4.cuh" - "quantization/dequantize_blockwise_bnb4.cu" - "quantization/matmul_bnb4.cc" - "quantization/matmul_bnb4.cuh" - "quantization/matmul_bnb4.cu" - "quantization/moe_quantization.h" - "quantization/moe_quantization.cc" - "quantization/quantize_dequantize_linear.cc" - "quantization/qordered_ops/qordered_attention_impl.cu" - "quantization/qordered_ops/qordered_attention_impl.h" - "quantization/qordered_ops/qordered_attention_input_enum.h" - "quantization/qordered_ops/qordered_attention.cc" - "quantization/qordered_ops/qordered_attention.h" - "quantization/qordered_ops/qordered_common.cuh" - "quantization/qordered_ops/qordered_layer_norm.h" - "quantization/qordered_ops/qordered_layer_norm.cc" - "quantization/qordered_ops/qordered_layer_norm_impl.h" - "quantization/qordered_ops/qordered_layer_norm_impl.cu" - "quantization/qordered_ops/qordered_longformer_attention.cc" - "quantization/qordered_ops/qordered_longformer_attention.h" - "quantization/qordered_ops/qordered_matmul.h" - "quantization/qordered_ops/qordered_matmul.cc" - "quantization/qordered_ops/qordered_matmul_utils.h" - "quantization/qordered_ops/qordered_matmul_utils.cc" - "quantization/qordered_ops/qordered_qdq_impl.cu" - "quantization/qordered_ops/qordered_qdq_impl.h" - "quantization/qordered_ops/qordered_qdq.cc" - "quantization/qordered_ops/qordered_qdq.h" - "quantization/qordered_ops/qordered_unary_ops.h" - "quantization/qordered_ops/qordered_unary_ops.cc" - "quantization/qordered_ops/qordered_unary_ops_impl.h" - "quantization/qordered_ops/qordered_unary_ops_impl.cu" - "cuda_contrib_kernels.cc" - "cuda_contrib_kernels.h" - "inverse.cc" - "fused_conv.cc" - "bert/group_query_attention.h" - "bert/group_query_attention.cc" - 
"bert/group_query_attention_impl.h" - "bert/group_query_attention_impl.cu" - "collective/custom_*" - "collective/distributed_*" - "collective/ipc_*" - "collective/shard*" -) - -if (NOT onnxruntime_USE_NCCL) - # Those are string patterns to exclude. Do NOT use stars such as - # collective/*.cc or *.h. - list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc") -endif() - -if (NOT onnxruntime_ENABLE_ATEN) - list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") -endif() - -set(provider_excluded_files - "atomic/common.cuh" - "cu_inc/common.cuh" - "math/einsum_utils/einsum_auxiliary_ops.cc" - "math/einsum_utils/einsum_auxiliary_ops.h" - "math/einsum_utils/einsum_auxiliary_ops_diagonal.cu" - "math/einsum_utils/einsum_auxiliary_ops_diagonal.h" - "math/einsum.cc" - "math/einsum.h" - "math/gemm.cc" - "math/matmul.cc" - "math/softmax_impl.cu" - "math/softmax_warpwise_impl.cuh" - "math/softmax_common.cc" - "math/softmax_common.h" - "math/softmax.cc" - "math/softmax.h" - "nn/conv.cc" - "nn/conv.h" - "nn/conv_transpose.cc" - "nn/conv_transpose.h" - "nn/pool.cc" - "nn/pool.h" - "reduction/reduction_ops.cc" - "rnn/cudnn_rnn_base.cc" - "rnn/cudnn_rnn_base.h" - "rnn/gru.cc" - "rnn/gru.h" - "rnn/lstm.cc" - "rnn/lstm.h" - "rnn/rnn.cc" - "rnn/rnn.h" - "rnn/rnn_impl.cu" - "rnn/rnn_impl.h" - "shared_inc/cuda_call.h" - "shared_inc/cudnn_fe_call.h" - "shared_inc/fpgeneric.h" - "cuda_allocator.cc" - "cuda_allocator.h" - "cuda_call.cc" - "cuda_common.cc" - "cuda_common.h" - "cuda_execution_provider_info.cc" - "cuda_execution_provider_info.h" - "cuda_execution_provider.cc" - "cuda_execution_provider.h" - "cuda_kernel.h" - "cuda_pch.cc" - "cuda_pch.h" - "cuda_profiler.cc" - "cuda_profiler.h" - "cuda_provider_factory.cc" - "cuda_provider_factory.h" - "cuda_stream_handle.cc", - "cuda_stream_handle.h", - "cuda_utils.cu" - "cudnn_common.cc" - "cudnn_common.h" - "cudnn_fe_call.cc" - "cupti_manager.cc" - "cupti_manager.h" - "fpgeneric.cu" - "gpu_data_transfer.cc" - "gpu_data_transfer.h" - "integer_gemm.cc" - "tunable/*" - "cuda_nhwc_kernels.cc" - "cuda_nhwc_kernels.h" -) - -set(training_ops_excluded_files - "activation/gelu_grad_impl_common.cuh" # uses custom tanh - "collective/adasum_kernels.cc" - "collective/adasum_kernels.h" - "math/div_grad.cc" # miopen API differs from cudnn, no double type support - "nn/batch_norm_grad.cc" # no double type support - "nn/batch_norm_grad.h" # miopen API differs from cudnn - "nn/batch_norm_internal.cc" # miopen API differs from cudnn, no double type support - "nn/batch_norm_internal.h" # miopen API differs from cudnn, no double type support - "nn/conv_grad.cc" - "nn/conv_grad.h" - "reduction/reduction_all.cc" # deterministic = true, ignore ctx setting - "reduction/reduction_ops.cc" # no double type support - "cuda_training_kernels.cc" - "cuda_training_kernels.h" - "nn/conv_shared.cc" - "nn/conv_shared.h" - "nn/conv_transpose_grad.cc" - "nn/conv_transpose_grad.h" -) - -function(auto_set_source_files_hip_language) - foreach(f ${ARGN}) - if(f MATCHES ".*\\.cu$") - set_source_files_properties(${f} PROPERTIES LANGUAGE HIP) - endif() - endforeach() -endfunction() - -# cuda_dir must be relative to REPO_ROOT -function(hipify cuda_dir in_excluded_file_patterns out_generated_cc_files out_generated_cu_files) - set(hipify_tool ${REPO_ROOT}/tools/ci_build/amd_hipify.py) - - file(GLOB_RECURSE srcs CONFIGURE_DEPENDS - "${REPO_ROOT}/${cuda_dir}/cuda/*.h" - "${REPO_ROOT}/${cuda_dir}/cuda/*.cc" - "${REPO_ROOT}/${cuda_dir}/cuda/*.cuh" - "${REPO_ROOT}/${cuda_dir}/cuda/*.cu" - ) - - # do 
exclusion - set(excluded_file_patterns ${${in_excluded_file_patterns}}) - list(TRANSFORM excluded_file_patterns PREPEND "${REPO_ROOT}/${cuda_dir}/cuda/") - file(GLOB_RECURSE excluded_srcs CONFIGURE_DEPENDS ${excluded_file_patterns}) - foreach(f ${excluded_srcs}) - message(STATUS "Excluded from hipify: ${f}") - endforeach() - list(REMOVE_ITEM srcs ${excluded_srcs}) - - foreach(f ${srcs}) - file(RELATIVE_PATH cuda_f_rel "${REPO_ROOT}" ${f}) - string(REPLACE "cuda" "rocm" rocm_f_rel ${cuda_f_rel}) - set(f_out "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/${rocm_f_rel}") - add_custom_command( - OUTPUT ${f_out} - COMMAND Python3::Interpreter ${hipify_tool} - --hipify_perl ${onnxruntime_HIPIFY_PERL} - ${f} -o ${f_out} - DEPENDS ${hipify_tool} ${f} - COMMENT "Hipify: ${cuda_f_rel} -> amdgpu/${rocm_f_rel}" - ) - if(f MATCHES ".*\\.cuh?$") - list(APPEND generated_cu_files ${f_out}) - else() - list(APPEND generated_cc_files ${f_out}) - endif() - endforeach() - - set_source_files_properties(${generated_cc_files} PROPERTIES GENERATED TRUE) - set_source_files_properties(${generated_cu_files} PROPERTIES GENERATED TRUE) - auto_set_source_files_hip_language(${generated_cu_files}) - set(${out_generated_cc_files} ${generated_cc_files} PARENT_SCOPE) - set(${out_generated_cu_files} ${generated_cu_files} PARENT_SCOPE) -endfunction() diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index d61512fa3cf09..3ec3c6ee1d5ae 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -53,12 +53,7 @@ endif() add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime") -if (onnxruntime_USE_ROCM) - target_compile_options(onnxruntime_session PRIVATE -Wno-sign-compare -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) - target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include) -# ROCM provider sources are generated, need to add include directory for generated headers - target_include_directories(onnxruntime_session PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) -endif() + if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_session PRIVATE ${ORTTRAINING_ROOT}) endif() diff --git a/cmake/onnxruntime_training.cmake b/cmake/onnxruntime_training.cmake index f289c73071fe6..fab40da1571a9 100644 --- a/cmake/onnxruntime_training.cmake +++ b/cmake/onnxruntime_training.cmake @@ -82,10 +82,7 @@ if (onnxruntime_BUILD_UNIT_TESTS) target_include_directories(onnxruntime_training_runner PRIVATE ${NCCL_INCLUDE_DIRS}) endif() - if (onnxruntime_USE_ROCM) - add_definitions(-DUSE_ROCM=1) - target_include_directories(onnxruntime_training_runner PUBLIC ${onnxruntime_ROCM_HOME}/include) - endif() + check_cxx_compiler_flag(-Wno-maybe-uninitialized HAS_NO_MAYBE_UNINITIALIZED) if(UNIX AND NOT APPLE) @@ -94,10 +91,6 @@ if (onnxruntime_BUILD_UNIT_TESTS) endif() endif() - if (onnxruntime_USE_ROCM) - target_compile_options(onnxruntime_training_runner PUBLIC -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) - endif() - set_target_properties(onnxruntime_training_runner PROPERTIES FOLDER "ONNXRuntimeTest") source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_training_runner_srcs} ${onnxruntime_perf_test_src}) @@ -176,11 +169,6 @@ if 
(onnxruntime_BUILD_UNIT_TESTS) onnxruntime_add_include_to_target(onnxruntime_training_bert onnxruntime_common onnx onnx_proto ${PROTOBUF_LIB} onnxruntime_training flatbuffers::flatbuffers Boost::mp11 safeint_interface) target_include_directories(onnxruntime_training_bert PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS} ${CXXOPTS} ${extra_includes} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/onnx onnxruntime_training_runner) - # ROCM provider sources are generated, need to add include directory for generated headers - if (onnxruntime_USE_ROCM) - target_include_directories(onnxruntime_training_bert PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) - endif() - target_link_libraries(onnxruntime_training_bert PRIVATE onnxruntime_training_runner onnxruntime_training ${ONNXRUNTIME_LIBS} ${onnxruntime_EXTERNAL_LIBRARIES}) set_target_properties(onnxruntime_training_bert PROPERTIES FOLDER "ONNXRuntimeTest") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9b33313b6147c..88ef0d91d7d64 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -453,7 +453,7 @@ if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_test_providers_src ${orttraining_test_trainingops_cpu_src}) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_ROCM) + if (onnxruntime_USE_CUDA) file(GLOB_RECURSE orttraining_test_trainingops_cuda_src CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/test/training_ops/cuda/*" ) @@ -613,10 +613,6 @@ if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() -if(onnxruntime_USE_ROCM) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_rocm) -endif() - if(onnxruntime_USE_COREML) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) endif() @@ -797,9 +793,6 @@ endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS}) endif() -if (onnxruntime_USE_ROCM) - target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) -endif() onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_common onnxruntime_framework onnxruntime_session GTest::gtest GTest::gmock onnx onnx_proto flatbuffers::flatbuffers nlohmann_json::nlohmann_json Boost::mp11 safeint_interface Eigen3::Eigen) if (onnxruntime_USE_DML) target_add_dml(onnxruntime_test_utils) @@ -916,12 +909,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR IOS) "${TEST_SRC_DIR}/providers/cpu/model_tests.cc" ) endif() -if (USE_ROCM) - # The following unit test takes about 40 minutes. 
- list(REMOVE_ITEM all_tests - "${TEST_SRC_DIR}/contrib_ops/matmul_4bits_test.cc" - ) -endif() set(test_all_args) if (onnxruntime_USE_TENSORRT OR onnxruntime_USE_NV) @@ -1003,16 +990,7 @@ endif() if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS) target_compile_definitions(onnxruntime_test_all PRIVATE DEBUG_NODE_INPUTS_OUTPUTS) endif() -if (onnxruntime_USE_ROCM) - if (onnxruntime_USE_COMPOSABLE_KERNEL) - target_compile_definitions(onnxruntime_test_all PRIVATE USE_COMPOSABLE_KERNEL) - if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE) - target_compile_definitions(onnxruntime_test_all PRIVATE USE_COMPOSABLE_KERNEL_CK_TILE) - endif() - endif() - target_compile_options(onnxruntime_test_all PRIVATE -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) - target_include_directories(onnxruntime_test_all PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) -endif() + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnxruntime_test_all PRIVATE Python::Python) endif() @@ -1140,9 +1118,7 @@ if (NOT IOS) target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json) target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT}) - if (onnxruntime_USE_ROCM) - target_include_directories(onnx_test_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) - endif() + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnx_test_runner PRIVATE Python::Python) endif() @@ -1263,9 +1239,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} ${CMAKE_CURRENT_BINARY_DIR}) - if (onnxruntime_USE_ROCM) - target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) - endif() + if (WIN32) target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings}) if (NOT DEFINED SYS_PATH_LIB) @@ -1371,9 +1345,7 @@ endif() if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_shared_lib_test_LIBS) endif() - if (onnxruntime_USE_ROCM) - list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) - endif() + if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() @@ -1407,10 +1379,7 @@ endif() if (onnxruntime_USE_NV) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) endif() - if (onnxruntime_USE_ROCM) - target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) - target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) - endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE @@ -1467,10 +1436,7 @@ endif() DEPENDS ${all_dependencies} ) - if (onnxruntime_USE_ROCM) - target_include_directories(onnxruntime_test_debug_node_inputs_outputs PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include) - target_include_directories(onnxruntime_test_debug_node_inputs_outputs PRIVATE 
${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) - endif(onnxruntime_USE_ROCM) + target_compile_definitions(onnxruntime_test_debug_node_inputs_outputs PRIVATE DEBUG_NODE_INPUTS_OUTPUTS) @@ -1506,6 +1472,8 @@ endif() "$<$>:/wd6326>") target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /bigobj>" + "$<$>:/bigobj>") endif() if(IOS) set_target_properties(onnxruntime_mlas_test PROPERTIES @@ -1616,14 +1584,6 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() endif() - if (onnxruntime_USE_ROCM) - list(APPEND custom_op_src_patterns - "${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/rocm_ops.hip" - "${TEST_SRC_DIR}/testdata/custom_op_library/rocm/rocm_ops.*") - list(APPEND custom_op_lib_include ${onnxruntime_ROCM_HOME}/include) - list(APPEND custom_op_lib_option "-D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1") - endif() - file(GLOB custom_op_src ${custom_op_src_patterns}) onnxruntime_add_shared_library(custom_op_library ${custom_op_src}) target_compile_options(custom_op_library PRIVATE ${custom_op_lib_option}) @@ -1840,10 +1800,9 @@ endif() if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) - onnxruntime_add_shared_library_module(example_plugin_ep - ${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.h - ${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.cc - ${TEST_SRC_DIR}/autoep/library/example_plugin_ep.cc) + file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h" + "${TEST_SRC_DIR}/autoep/library/*.cc") + onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src}) target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) target_link_libraries(example_plugin_ep PRIVATE onnxruntime) @@ -1950,8 +1909,7 @@ endif() if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS|visionOS|tvOS" AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android" - AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" - AND NOT onnxruntime_USE_ROCM) + AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") file(GLOB_RECURSE test_execution_provider_srcs "${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/*.h" "${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/*.cc" diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 486439a68b7ff..ffe866164a411 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -224,113 +224,11 @@ else() if (onnxruntime_USE_WEBGPU) string(APPEND EXPORTED_FUNCTIONS ",_wgpuBufferRelease,_wgpuCreateInstance") endif() - - if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - set(MAXIMUM_MEMORY "17179869184") - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:-s MEMORY64=1" - ) - string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental") - string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental") - set(SMEMORY_FLAG "-sMEMORY64") - - target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - 
target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - # target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_crc_cpu_detect PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_conditions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_check_op PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_message PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_format PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_log_sink_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_sink PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - 
target_compile_options(absl_log_entry PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_vlog_config_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_fnmatch PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_examine_stack PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_strerror PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_internal_nullguard PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_raw_logging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_log_severity PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - if (onnxruntime_USE_EXTENSIONS) - target_compile_options(ortcustomops PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(ocos_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - endif() - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:--post-js 
\"${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js\"" - ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js") - else () - set(MAXIMUM_MEMORY "4294967296") - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js.js\"" - ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js") - endif () - + set(MAXIMUM_MEMORY "4294967296") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js.js\"" + ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js") target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" @@ -347,42 +245,7 @@ else() --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) - if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - set(SIGNATURE_CONVERSIONS "OrtRun:_pppppppp,\ -OrtRunWithBinding:_ppppp,\ -OrtGetTensorData:_ppppp,\ -OrtCreateTensor:p_pppp_,\ -OrtCreateSession:pppp,\ -OrtReleaseSession:_p,\ -OrtGetInputOutputCount:_ppp,\ -OrtCreateSessionOptions:pp__p_ppppp,\ -OrtReleaseSessionOptions:_p,\ -OrtAppendExecutionProvider:_pp,\ -OrtAddSessionConfigEntry:_ppp,\ -OrtGetInputName:ppp,\ -OrtGetOutputName:ppp,\ -OrtCreateRunOptions:ppp_p,\ -OrtReleaseRunOptions:_p,\ -OrtReleaseTensor:_p,\ -OrtFree:_p,\ -OrtCreateBinding:_p,\ -OrtBindInput:_ppp,\ -OrtBindOutput:_ppp_,\ -OrtClearBoundOutputs:_p,\ -OrtReleaseBinding:_p,\ -OrtGetLastError:_pp,\ -JsepOutput:pp_p,\ -JsepGetNodeName:pp,\ -JsepOutput:pp_p,\ -jsepCopy:_pp_,\ -jsepCopyAsync:_pp_,\ -jsepDownload:_pp_") - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" - "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" - ) - endif () - + if (onnxruntime_USE_JSEP) # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU # This flag allows async functions to be called from sync functions, in the cost of binary size and @@ -393,13 +256,7 @@ jsepDownload:_pp_") "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" ) list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js") - - if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" - "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" - ) - endif() + endif() if (onnxruntime_USE_WEBGPU) @@ -469,9 +326,7 @@ jsepDownload:_pp_") endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. 
- if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") - endif() + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING) target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs) diff --git a/cmake/patches/dawn/dawn_binskim.patch b/cmake/patches/dawn/dawn_binskim.patch new file mode 100644 index 0000000000000..213eb644dc5a2 --- /dev/null +++ b/cmake/patches/dawn/dawn_binskim.patch @@ -0,0 +1,13 @@ +diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt +index 0d5f52a358..90764d21b5 100644 +--- a/src/tint/CMakeLists.txt ++++ b/src/tint/CMakeLists.txt +@@ -143,8 +143,6 @@ function(tint_default_compile_options TARGET) + /W4 + /wd4068 # unknown pragma + /wd4127 # conditional expression is constant +- /wd4244 # 'conversion' conversion from 'type1' to 'type2', possible loss of data +- /wd4267 # 'var' : conversion from 'size_t' to 'type', possible loss of data + /wd4324 # 'struct_name' : structure was padded due to __declspec(align()) + /wd4459 # declaration of 'identifier' hides global declaration + /wd4458 # declaration of 'identifier' hides class member diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 30d5a44a1d1cc..f51370212ff5a 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 8b5af303..7fe05a5a 100644 +index 8b5af303..8593fe4a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF) @@ -47,15 +47,7 @@ index 8b5af303..7fe05a5a 100644 add_library(onnx_proto ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) add_dependencies(onnx_proto gen_onnx_operators_proto gen_onnx_data_proto) -@@ -492,6 +507,7 @@ if(MSVC) - endif() - else() - # On non-Windows, hide all symbols we don't need -+ set(EXTRA_FLAGS "-Wno-unused-parameter") - set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") - set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) - set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -595,13 +611,6 @@ if(ONNX_BUILD_PYTHON) +@@ -595,13 +610,6 @@ if(ONNX_BUILD_PYTHON) target_link_libraries(onnx_cpp2py_export PRIVATE ${Python3_LIBRARIES}) target_compile_options(onnx_cpp2py_export PRIVATE /MP @@ -69,7 +61,7 @@ index 8b5af303..7fe05a5a 100644 ${EXTRA_FLAGS}) add_msvc_runtime_flag(onnx_cpp2py_export) add_onnx_global_defines(onnx_cpp2py_export) -@@ -618,23 +627,9 @@ endif() +@@ -618,23 +626,9 @@ endif() if(MSVC) target_compile_options(onnx_proto PRIVATE /MP @@ -165,38 +157,3 @@ index acf3aac7..5bef6e72 100644 OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); } static void -diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h -index 0aab3e26..27f32195 100644 ---- a/onnx/onnx_pb.h -+++ b/onnx/onnx_pb.h -@@ -47,10 +47,30 @@ - #define ONNX_API ONNX_IMPORT - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic push -+ -+// In file included from onnx/onnx-ml.pb.h:30: -+// In file included from google/protobuf/extension_set.h:53: -+// google/protobuf/parse_context.h:328:47: error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] -+#if defined(__has_warning) -+#if __has_warning("-Wshorten-64-to-32") -+#pragma GCC diagnostic ignored "-Wshorten-64-to-32" 
-+#endif -+#endif // defined(__has_warning) -+ -+#endif // defined(__GNUC__) -+ -+ - #ifdef ONNX_ML - #include "onnx/onnx-ml.pb.h" - #else - #include "onnx/onnx.pb.h" - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic pop -+#endif -+ -+ - #endif // ! ONNX_ONNX_PB_H diff --git a/cmake/utils/detect_cuda_arch.cu b/cmake/utils/detect_cuda_arch.cu index 83fbc13dbff7f..52a51697326ff 100644 --- a/cmake/utils/detect_cuda_arch.cu +++ b/cmake/utils/detect_cuda_arch.cu @@ -4,36 +4,30 @@ #include #include -int main(int argc, char* argv[]) -{ - int n_devices = 0; - int rc = cudaGetDeviceCount(&n_devices); - if (rc != cudaSuccess) - { - cudaError_t error = cudaGetLastError(); - std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; - return rc; - } +int main(int argc, char* argv[]) { + int n_devices = 0; + int rc = cudaGetDeviceCount(&n_devices); + if (rc != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } - std::vector> arch(n_devices); - for (int cd = 0; cd < n_devices; ++cd) - { - cudaDeviceProp dev; - int rc = cudaGetDeviceProperties(&dev, cd); - if (rc != cudaSuccess) - { - cudaError_t error = cudaGetLastError(); - std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; - return rc; - } - else - { - arch[cd] = {dev.major, dev.minor}; - } + std::vector> arch(n_devices); + for (int cd = 0; cd < n_devices; ++cd) { + cudaDeviceProp dev; + int rc = cudaGetDeviceProperties(&dev, cd); + if (rc != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } else { + arch[cd] = {dev.major, dev.minor}; } + } - std::pair best_cc = *std::max_element(begin(arch), end(arch)); - std::cout << best_cc.first << best_cc.second; + std::pair best_cc = *std::max_element(begin(arch), end(arch)); + std::cout << best_cc.first << best_cc.second; - return 0; + return 0; } diff --git a/cmake/vcpkg-configuration.json b/cmake/vcpkg-configuration.json index 54696dc9f2c82..131f6ec779d49 100644 --- a/cmake/vcpkg-configuration.json +++ b/cmake/vcpkg-configuration.json @@ -2,10 +2,10 @@ "default-registry": { "kind": "git", "repository": "https://github.com/Microsoft/vcpkg", - "baseline": "ce613c41372b23b1f51333815feb3edd87ef8a8b" + "baseline": "ef7dbf94b9198bc58f45951adcf1f041fcbc5ea0" }, "overlay-ports": [ - "./vcpkg-ports" + "./vcpkg-ports" ], "registries": [] } diff --git a/cmake/vcpkg-ports/onnx/binskim.patch b/cmake/vcpkg-ports/onnx/binskim.patch index 30d5a44a1d1cc..f51370212ff5a 100644 --- a/cmake/vcpkg-ports/onnx/binskim.patch +++ b/cmake/vcpkg-ports/onnx/binskim.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 8b5af303..7fe05a5a 100644 +index 8b5af303..8593fe4a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." 
OFF) @@ -47,15 +47,7 @@ index 8b5af303..7fe05a5a 100644 add_library(onnx_proto ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) add_dependencies(onnx_proto gen_onnx_operators_proto gen_onnx_data_proto) -@@ -492,6 +507,7 @@ if(MSVC) - endif() - else() - # On non-Windows, hide all symbols we don't need -+ set(EXTRA_FLAGS "-Wno-unused-parameter") - set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") - set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) - set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -595,13 +611,6 @@ if(ONNX_BUILD_PYTHON) +@@ -595,13 +610,6 @@ if(ONNX_BUILD_PYTHON) target_link_libraries(onnx_cpp2py_export PRIVATE ${Python3_LIBRARIES}) target_compile_options(onnx_cpp2py_export PRIVATE /MP @@ -69,7 +61,7 @@ index 8b5af303..7fe05a5a 100644 ${EXTRA_FLAGS}) add_msvc_runtime_flag(onnx_cpp2py_export) add_onnx_global_defines(onnx_cpp2py_export) -@@ -618,23 +627,9 @@ endif() +@@ -618,23 +626,9 @@ endif() if(MSVC) target_compile_options(onnx_proto PRIVATE /MP @@ -165,38 +157,3 @@ index acf3aac7..5bef6e72 100644 OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); } static void -diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h -index 0aab3e26..27f32195 100644 ---- a/onnx/onnx_pb.h -+++ b/onnx/onnx_pb.h -@@ -47,10 +47,30 @@ - #define ONNX_API ONNX_IMPORT - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic push -+ -+// In file included from onnx/onnx-ml.pb.h:30: -+// In file included from google/protobuf/extension_set.h:53: -+// google/protobuf/parse_context.h:328:47: error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] -+#if defined(__has_warning) -+#if __has_warning("-Wshorten-64-to-32") -+#pragma GCC diagnostic ignored "-Wshorten-64-to-32" -+#endif -+#endif // defined(__has_warning) -+ -+#endif // defined(__GNUC__) -+ -+ - #ifdef ONNX_ML - #include "onnx/onnx-ml.pb.h" - #else - #include "onnx/onnx.pb.h" - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic pop -+#endif -+ -+ - #endif // ! ONNX_ONNX_PB_H diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 4544b0daf93cd..bfb6e7c38ccb4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -946,7 +946,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br>**T2** = tensor(uint8)|
-|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br>**T2** = tensor(uint8)<br>**T3** = tensor(float), tensor(float16), tensor(uint8)|
+|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br>**T2** = tensor(uint8)<br>**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)<br>**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br>
**Tid** = tensor(int64)| diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 0660cc874ffb7..6f519249b98b6 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -7,6 +7,7 @@ #include "core/common/common.h" #include "core/framework/allocator_stats.h" +#include "core/session/abi_key_value_pairs.h" // some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemory.h #include "core/session/onnxruntime_c_api.h" #include "core/framework/ortdevice.h" @@ -37,6 +38,26 @@ struct OrtArenaCfg { int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default int64_t max_power_of_two_extend_bytes; // use -1 to allow ORT to choose the default + + bool IsValid() { + return arena_extend_strategy >= -1 && arena_extend_strategy <= 1 && + initial_chunk_size_bytes >= -1 && + max_dead_bytes_per_chunk >= -1 && + initial_growth_chunk_size_bytes >= -1 && + max_power_of_two_extend_bytes >= -1; + } + + // config key names that we parse in FromKeyValuePairs + struct ConfigKeyNames { + static constexpr const char* ArenaExtendStrategy = "arena.extend_strategy"; + static constexpr const char* InitialChunkSizeBytes = "arena.initial_chunk_size_bytes"; + static constexpr const char* MaxDeadBytesPerChunk = "arena.max_dead_bytes_per_chunk"; + static constexpr const char* InitialGrowthChunkSizeBytes = "arena.initial_growth_chunk_size_bytes"; + static constexpr const char* MaxPowerOfTwoExtendBytes = "arena.max_power_of_two_extend_bytes"; + static constexpr const char* MaxMem = "arena.max_mem"; + }; + + static onnxruntime::common::Status FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg); }; namespace onnxruntime { diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 30a5735c4e493..65a6ec304bda2 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -5,6 +5,8 @@ #ifndef SHARED_PROVIDER #include +#include +#include #include #include @@ -35,6 +37,7 @@ class GraphOptimizerRegistry; #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +struct OrtEpDevice; struct OrtRunOptions; namespace onnxruntime { @@ -62,7 +65,9 @@ using RunOptions = ::OrtRunOptions; enum class DataLayout { NCHW, NHWC, - NCHWC, + + // NCHW is the default ONNX standard data layout. So default to it. + Default = NCHW, }; class IExecutionProvider { @@ -323,9 +328,21 @@ class IExecutionProvider { } virtual DataLayout GetPreferredLayout() const { - // NCHW is the default ONNX standard data layout. So default to it. // EPs which prefer a different layout should override to return their preferred layout. - return DataLayout::NCHW; + return DataLayout::Default; + } + + /** + Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout should be + converted to `target_data_layout`. + If the EP prefers a non-default data layout (see `GetPreferredLayout()`), this function will be called during + layout transformation with `target_data_layout` set to the EP's preferred data layout. + A return value of `std::nullopt` indicates that this decision is left to ORT. 
+ */ + virtual std::optional ShouldConvertDataLayoutForOp(std::string_view /*domain*/, + std::string_view /*op_type*/, + DataLayout /*target_data_layout*/) const { + return std::nullopt; } virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {} diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index fab65e8fee692..e63ab044834f5 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -43,7 +43,8 @@ struct OrtRunOptions { #endif // Stores the configurations for this run - // To add an configuration to this specific run, call OrtApis::AddRunConfigEntry + // To add a configuration value to this specific run, call OrtApis::AddRunConfigEntry + // To get a configuration value, call OrtApis::GetRunConfigEntry // The configuration keys and value formats are defined in // /include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h onnxruntime::ConfigOptions config_options; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 6883d3ef644d8..54e03a31fceef 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -817,15 +817,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /// /// Returns the initializer's TensorProto if 'name' is an initializer (either constant or overridable). - /// If the initializer is not found, a nullptr is returned. An output parameter is set to true if the initializer - /// is constant. + /// If the initializer is not found, a nullptr is returned. Also returns, via output parameters, the + /// OrtValue that holds the actual data (if any) and a boolean that indicates if the initializer is constant. /// /// The initializer's name. - /// Checks outer scope if set to true and the graph is a subgraph. + /// Output OrtValue that is populated if the initializer tensor proto has been configured + /// to use external data that points to an OrtValue. Is set to an unallocated OrtValue for + /// a tensor proto that holds the weight data (e.g., small initializers). /// Output parameter set to true if the initializer is a constant. + /// Checks outer scope if set to true and the graph is a subgraph. /// The initializer's TensorProto or nullptr. - const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope, - bool& is_constant) const; + const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, OrtValue& value, + bool& is_constant, bool check_outer_scope = false) const; /** Gets the Graph inputs excluding initializers. These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs. 
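To make the new layout hooks in execution_provider.h concrete, here is a minimal, hypothetical sketch of an EP that prefers NHWC and opts a single op type out of conversion. It assumes the returned optional's value type is bool (consistent with the `std::nullopt` wording above); `MyNhwcEp`, its base-class constructor call, and the choice of "Resize" are illustrative only and not part of this change.

```cpp
// Sketch only: MyNhwcEp is a hypothetical execution provider used to illustrate the hooks.
#include <optional>
#include <string_view>
#include "core/framework/execution_provider.h"

namespace onnxruntime {

class MyNhwcEp : public IExecutionProvider {
 public:
  MyNhwcEp() : IExecutionProvider("MyNhwcEp") {}  // assumes the type-name constructor

  // Prefer channels-last; ORT then runs layout transformation with NHWC as the target.
  DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; }

  // Opt individual ops in or out of the conversion; std::nullopt leaves the decision to ORT.
  std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view domain,
                                                   std::string_view op_type,
                                                   DataLayout target_data_layout) const override {
    if (target_data_layout == DataLayout::NHWC && domain.empty() && op_type == "Resize") {
      return false;  // e.g. keep this op in the default (NCHW) layout
    }
    return std::nullopt;
  }
};

}  // namespace onnxruntime
```

Per the doc comment above, ORT only consults this hook during layout transformation when the EP's preferred layout differs from the default.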
diff --git a/include/onnxruntime/core/graph/model_saving_options.h b/include/onnxruntime/core/graph/model_saving_options.h index 45536a6967606..6c041ec96a035 100644 --- a/include/onnxruntime/core/graph/model_saving_options.h +++ b/include/onnxruntime/core/graph/model_saving_options.h @@ -39,6 +39,9 @@ struct ModelSavingOptions { #else int64_t allocation_granularity = 4096; #endif + // Force embed all external initializer into the Onnx file + // Used for EPContext model generation while some nodes fallback on CPU which has external data dependency + bool force_embed_external_ini = false; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 1d5b2d5513044..9f3a9eeabff6b 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -27,8 +27,13 @@ namespace onnxruntime { class EpFactoryInternal; class InferenceSession; struct IExecutionProviderFactory; +struct OrtAllocatorImplWrappingIAllocator; struct SessionOptions; +namespace plugin_ep { +class DataTransfer; +} // namespace plugin_ep + /** Provides the runtime environment for onnxruntime. Create one instance for the duration of execution. @@ -85,7 +90,7 @@ class Environment { * Registers an allocator for sharing between multiple sessions. * Return an error if an allocator with the same OrtMemoryInfo is already registered. */ - Status RegisterAllocator(AllocatorPtr allocator); + Status RegisterAllocator(OrtAllocator* allocator); /** * Creates and registers an allocator for sharing between multiple sessions. @@ -130,7 +135,16 @@ class Environment { const std::vector& GetOrtEpDevices() const { return execution_devices_; } + + Status CreateSharedAllocator(const OrtEpDevice& ep_device, + OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator); + Status ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type); #endif // !defined(ORT_MINIMAL_BUILD) + + // return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator + Status GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator); + ~Environment(); private: @@ -140,12 +154,37 @@ class Environment { const OrtThreadingOptions* tp_options = nullptr, bool create_global_thread_pools = false); + Status RegisterAllocatorImpl(AllocatorPtr allocator); + Status UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool error_if_not_found = true); + Status CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, + const OrtMemoryInfo& memory_info, OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator, + bool replace_existing); + std::unique_ptr logging_manager_; std::unique_ptr intra_op_thread_pool_; std::unique_ptr inter_op_thread_pool_; bool create_global_thread_pools_{false}; + + std::mutex mutex_; + + // shared allocators from various sources. + // CreateAndRegisterAllocator[V2]: IAllocator allocators created by ORT + // RegisterAllocator: IAllocatorImplWrappingOrtAllocator custom allocators registered by the user. + // TODO: How can we detect registration of an allocator from an InferenceSession? + // OrtEpDevice: We create a default shared IAllocatorImplWrappingOrtAllocator for each OrtEpDevice memory info. std::vector shared_allocators_; + // RegisterAllocator and CreateSharedAllocator pointers. Used for GetSharedAllocator. 
+ // Every instance here is also in shared_allocators_. + std::unordered_set shared_ort_allocators_; + + // OrtAllocator wrapped CPUAllocator::DefaultInstance that is returned by GetSharedAllocator when no plugin EP is + // providing a CPU allocator. + std::unique_ptr default_cpu_ort_allocator_; + + using OrtAllocatorUniquePtr = std::unique_ptr>; + #if !defined(ORT_MINIMAL_BUILD) // register EPs that are built into the ORT binary so they can take part in AutoEP selection // added to ep_libraries diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c860e0794abed..9106cd94ad031 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -52,8 +52,9 @@ extern "C" { #define _In_opt_ #define _In_opt_z_ #define _Out_ -#define _Outptr_ #define _Out_opt_ +#define _Outptr_ +#define _Outptr_opt_ #define _Inout_ #define _Inout_opt_ #define _Frees_ptr_opt_ @@ -61,6 +62,7 @@ extern "C" { #define _Ret_notnull_ #define _Check_return_ #define _Outptr_result_maybenull_ +#define _Outptr_result_maybenull_z_ #define _In_reads_(X) #define _Inout_updates_(X) #define _Out_writes_(X) @@ -497,6 +499,7 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e typedef enum OrtTypeTag { ORT_TYPE_TAG_Void, ORT_TYPE_TAG_OrtValueInfo, + ORT_TYPE_TAG_OrtOpAttr, ORT_TYPE_TAG_OrtNode, ORT_TYPE_TAG_OrtGraph, } OrtTypeTag; @@ -795,9 +798,6 @@ typedef struct OrtCompileApi OrtCompileApi; struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; -struct OrtNodeComputeInfo; -typedef struct OrtNodeComputeInfo OrtNodeComputeInfo; - /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -4583,7 +4583,8 @@ struct OrtApi { * \param[in] provider_options_values value of the provider options map * \param[in] num_keys Length of the provider options map */ - ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, + ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, + _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); /** \brief Run the model asynchronously in a thread owned by intra op thread pool @@ -4963,6 +4964,8 @@ struct OrtApi { ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + /// @} + /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. * \since Version 1.22. */ @@ -5879,6 +5882,52 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs); + /** \brief Returns a node's attributes as OrtOpAttr instances. + * + * \param[in] node The OrtNode instance. + * \param[out] attributes Output parameter set to the OrtArrayOfConstObjects instance containing the node's attributes + * as OrtOpAttr instances. Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes); + + /** \brief Gets the OrtNode's attribute as OrtOpAttr by name. + * + * \param[in] node The OrtNode instance. + * \param[in] attribute_name The name of the attribute + * \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute); + + /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. + * + * \param[in] attribute The OrtOpAttr instance. + * \param[out] type Output the attribute type as OrtOpAttrType. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); + + /** \brief Get the attribute name from an OrtOpAttr. + * + * \param[in] attribute The OrtOpAttr instance. + * \param[out] name Output parameter set to the attribute's name. The name is a null-terminated string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); + /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node. * * Certain operator types (e.g., If and Loop) contain nested subgraphs. @@ -5907,6 +5956,117 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetParentGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** parent_graph); + + /// \name OrtRunOptions + /// @{ + + /** \brief Get a run configuration entry. + * + * If a run configuration entry with key `config_key` doesn't exist, `config_value` will be set to NULL. + * + * `config_key`s are defined in onnxruntime_run_options_config_keys.h. + * + * \param[in] options The OrtRunOptions instance. + * \param[in] config_key The configuration entry key. A null-terminated string. + * \param[out] config_value Output parameter set to the configuration entry value. Either a null-terminated string if + * a configuration entry exists or NULL otherwise. + * Do not free this value. It is owned by `options` and will be invalidated if another call + * to `AddRunConfigEntry()` overwrites it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetRunConfigEntry, _In_ const OrtRunOptions* options, + _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); + + /// @} + + /** \brief Get the OrtMemoryInfo for the device. + * + * \param[in] ep_device The OrtEpDevice instance to query. + * \return A pointer to the OrtMemoryInfo for the device. + * + * \since Version 1.23 + */ + ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device); + + /** \brief Create/replace a shared allocator for the OrtEpDevice in the OrtEnv. + * + * OrtEpDevice maps to the EP factory, and the factory provides the allocator implementation. + * + * Both OrtDeviceMemoryType_DEFAULT and OrtDeviceMemoryType_HOST_ACCESSIBLE are optional for an EP to provide. + * It is EP implementation dependent as to what is available. + * + * If a shared allocator already exists for the OrtEpDevice and OrtDeviceMemoryType, it is replaced. This allows + * changing the shared allocator configuration from the default. e.g. adding an arena. 
+ * + * \param[in] env The OrtEnv instance to create the shared allocator in. + * \param[in] ep_device The OrtEpDevice instance to create the shared allocator for. + * \param[in] mem_type The memory type to use for the shared allocator. + * \param[in] allocator_type The type of allocator to create (e.g. OrtAllocatorType::OrtArenaAllocator). + * \param[in] allocator_options Optional key-value pairs to configure the allocator. If arena based, see + * include/onnxruntime/core/framework/allocator.h for the keys and values that can be + * used. + * \param[out] allocator A pointer to the created shared allocator. Owned by the OrtEnv instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(CreateSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType mem_type, _In_ OrtAllocatorType allocator_type, + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_opt_ OrtAllocator** allocator); + + /** \brief Get a shared allocator from the OrtEnv. + * + * By default there is a shared allocator created for all OrtEpDevice instances, so if you get the OrtMemoryInfo + * from the OrtEpDevice using EpDevice_MemoryInfo a shared allocator is guaranteed to exist. + * + * This will also match and return custom allocators added with RegisterAllocator. + * + * It is not an error to not find a matching allocator. + * + * \param[in] env The OrtEnv instance to get the shared allocator from. + * \param[in] mem_info The OrtMemoryInfo instance to get the shared allocator for. + * \param[out] allocator A pointer to the shared allocator, or nullptr if no shared allocator exists for + * the given memory info. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(GetSharedAllocator, _In_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, + _Outptr_result_maybenull_ OrtAllocator** allocator); + + /** \brief Release a shared allocator from the OrtEnv for the OrtEpDevice and memory type. + * + * This will release the shared allocator for the given OrtEpDevice and memory type. + * If no shared allocator exists, this is a no-op. + * + * \param[in] env The OrtEnv instance to release the shared allocator from. + * \param[in] ep_device The OrtEpDevice instance to release the shared allocator for. + * \param[in] mem_type The memory type of the shared allocator to release. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType mem_type); + + /** \brief Get a const pointer to the raw data inside a tensor + * + * Used to read the internal tensor data directly. + * \note The returned pointer is valid until the \p value is destroyed. 
+ * + * \param[in] value A tensor type (string tensors are not supported) + * \param[out] out Filled in with a pointer to the internal storage + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 08e8736e9e591..c59baa59c91a5 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -942,6 +942,7 @@ struct RunOptions : detail::Base { const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry + const char* GetConfigEntry(const char* config_key); ///< Wraps OrtApi::GetRunConfigEntry /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance * diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 25936038ba297..612adc81d3309 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -767,6 +767,12 @@ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char return *this; } +inline const char* RunOptions::GetConfigEntry(const char* config_key) { + const char* out{}; + ThrowOnError(GetApi().GetRunConfigEntry(p_, config_key, &out)); + return out; +} + inline RunOptions& RunOptions::SetTerminate() { ThrowOnError(GetApi().RunOptionsSetTerminate(p_)); return *this; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 68b6992177b0d..c53a2f42247d9 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -10,8 +10,101 @@ extern "C" { ORT_RUNTIME_CLASS(Ep); ORT_RUNTIME_CLASS(EpFactory); ORT_RUNTIME_CLASS(EpGraphSupportInfo); +ORT_RUNTIME_CLASS(MemoryDevice); // opaque class to wrap onnxruntime::OrtDevice ORT_RUNTIME_CLASS(NodeComputeContext); +// Opaque class to create an onnxruntime::Stream. Will be filled out in separate PR. +// Adding here for OrtDataTransferImpl as the stream type is required by the IDataTransfer API. +ORT_RUNTIME_CLASS(SyncStream); + +// struct that an EP implements for IDataTransfer to copy between devices it uses and CPU +typedef struct OrtDataTransferImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Release the OrtDataTransferImpl instance. + * + * This is called by ORT when the OrtDataTransferImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtDataTransferImpl instance. + * + * \since Version 1.23. + */ + ORT_API_T(void, Release, _In_ void* this_ptr); + + /** \brief Check if the implementation can copy between the source and destination memory devices. + * + * \param[in] this_ptr Pointer to the OrtDataTransferImpl instance. + * \param[in] src_memory_device Source OrtMemoryDevice to copy from. + * \param[in] dst_memory_device Destination OrtMemoryDevice to copy to. + * \return True if the implementation can copy between the devices. + * + * \since Version 1.23. 
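+ *
+ * Illustrative implementation sketch (`MyDataTransfer`, its cached `ep_api` pointer and
+ * `my_device` are hypothetical EP-side state):
+ * \code{.cpp}
+ *   static bool ORT_API_CALL CanCopyImpl(void* this_ptr, const OrtMemoryDevice* src,
+ *                                        const OrtMemoryDevice* dst) {
+ *     auto& impl = *static_cast<MyDataTransfer*>(this_ptr);
+ *     // Allow copies where either endpoint is the device this EP owns; the other endpoint is
+ *     // expected to be CPU or host-accessible memory.
+ *     return impl.ep_api->MemoryDevice_AreEqual(src, impl.my_device) ||
+ *            impl.ep_api->MemoryDevice_AreEqual(dst, impl.my_device);
+ *   }
+ * \endcode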
+ */ + ORT_API_T(bool, CanCopy, _In_ void* this_ptr, + _In_ const OrtMemoryDevice* src_memory_device, _In_ const OrtMemoryDevice* dst_memory_device); + + /** \brief Copy tensors from src_tensors to dst_tensors using the provided streams. + * + * The implementation can use the provided streams to perform asynchronous copies if supported. + * If a stream is not available, the copy is performed synchronously. + * + * \param[in] this_ptr Pointer to the OrtDataTransferImpl instance. + * \param[in] src_tensors Array of source OrtValue pointers to copy from. + * \param[in] dst_tensors Array of destination OrtValue pointers to copy to. + * \param[in] streams Array of OrtSyncStream pointers for the copy operations, if the execution provider is stream + * aware. nullptr if it is not. + * \param[in] num_tensors Number of tensors to copy. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CopyTensors, _In_ void* this_ptr, + _In_reads_(num_tensors) const OrtValue** src_tensors, + _In_reads_(num_tensors) OrtValue** dst_tensors, + _In_reads_(num_tensors) OrtSyncStream** streams, + _In_ size_t num_tensors); +} OrtDataTransferImpl; + +struct OrtNodeFusionOptions; +typedef struct OrtNodeFusionOptions OrtNodeFusionOptions; + +struct OrtNodeComputeInfo; +typedef struct OrtNodeComputeInfo OrtNodeComputeInfo; + +/** + * \brief The OrtNodeFusionOptions struct specifies options for fusing nodes supported by an execution provider. + * + * Refer to OrtEpApi::EpGraphSupportInfo_AddNodesToFuse. + * + * \since Version 1.23. + */ +struct OrtNodeFusionOptions { + /** \brief The ONNX Runtime version the OrtNodeFusionOptions was compiled with. + * + * Implementation should set to ORT_API_VERSION. + * ORT will use this to ensure it does not use members that were not available when the EP library was compiled. + * + * \since Version 1.23. + */ + uint32_t ort_version_supported; + + /** \brief If set to true, specify that the execution provider does not require ONNX Runtime to provide constant + * initializers as inputs to the fused node during model inference. This is used when the execution + * provider saves a copy of constant initializers, and allows ONNX Runtime to release constant initializers that + * are not used by any execution provider. + * + * If not specified, defaults to false. That is, ONNX Runtime provides constant initializers as inputs to + * the fused node by default. + * + * \since Version 1.23. + */ + bool drop_constant_initializers; + + // const OrtNode* fused_node_schema; +}; + /** * \brief The OrtNodeComputeInfo struct provides functions that an OrtEp implements to specify the compute * function for a compiled OrtGraph instance. @@ -21,7 +114,7 @@ struct OrtNodeComputeInfo { /** \brief The ONNX Runtime version the OrtNodeComputeInfo was compiled with. * * Implementation should set to ORT_API_VERSION. - * ORT will use this to ensure it does not call functions that were not available when the library was compiled. + * ORT will use this to ensure it does not call functions that were not available when the EP library was compiled. * * \since Version 1.23. */ @@ -87,9 +180,6 @@ struct OrtEpApi { ORT_CLASS_RELEASE(EpDevice); /** \brief Specify nodes that are supported by an OrtEp and should be fused into one node. - * - * IMPORTANT: This is not the final version of this API function. This is currently experimental but will - * be stabilized by the ONNX Runtime 1.23 release. 
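+ *
+ * Illustrative call from an EP's graph-support callback (a sketch; `ep_api`, `supported_nodes`
+ * and `num_supported` are hypothetical names):
+ * \code{.cpp}
+ *   OrtNodeFusionOptions fusion_options = {};
+ *   fusion_options.ort_version_supported = ORT_API_VERSION;
+ *   fusion_options.drop_constant_initializers = true;  // this EP keeps its own copy of initializers
+ *   OrtStatus* status = ep_api->EpGraphSupportInfo_AddNodesToFuse(
+ *       graph_support_info, supported_nodes, num_supported, &fusion_options);
+ * \endcode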
* * Because the nodes will be fused into one "fused node", there must not exist an unsupported node in * a path between two of the provided nodes. Otherwise, the graph will become invalid. @@ -100,14 +190,15 @@ struct OrtEpApi { * \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes. * \param[in] nodes Array of nodes supported by the EP that should be fused/compiled. * \param[in] num_nodes The number of supported nodes. + * \param[in] node_fusion_options Optional node fusion options. Ignored if set to NULL. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes - /*, OrtFusedNodeSchema* optional_fused_node_schema, OrtNodesToOptimizeInfo* nodes_to_opt*/); + _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes, + _In_opt_ const OrtNodeFusionOptions* node_fusion_options); /** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel. * @@ -133,8 +224,117 @@ struct OrtEpApi { * \since Version 1.23. */ ORT_API_T(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context); + + /** \brief Register an allocator with the OrtEpDevice. + * + * This allows an EP to provide OrtMemoryInfo for DEFAULT and HOST_ACCESSIBLE memory type as needed. + * The registered values will be used in calls to OrtEpFactory::CreateAllocator to ensure the required allocator/s + * are available for EP usage. + * + * At most one DEFAULT and one HOST_ACCESSIBLE entry can be added. + * Multiple calls for the same memory type will replace a previous entry. + * + * \param[in] ep_device The OrtEpDevice instance to register the OrtMemoryInfo with. + * \param[in] allocator_memory_info The OrtMemoryInfo information for the allocator. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device, + _In_ const OrtMemoryInfo* allocator_memory_info); + + /** \brief Get the OrtMemoryDevice from an OrtMemoryInfo instance. + * + * This is required for OrtDataTransferImpl (which implements onnxruntime::IDataTransfer) where the OrtMemoryDevice + * is used in the CanCopy and CopyTensors functions. + * + * \param[in] memory_info The OrtMemoryInfo instance to get the memory device from. + * \return The OrtMemoryDevice associated with the OrtMemoryInfo instance. + * + * \since Version 1.23. + */ + ORT_API_T(const OrtMemoryDevice*, MemoryInfo_GetMemoryDevice, _In_ const OrtMemoryInfo* memory_info); + + /** \brief Get the OrtMemoryDevice from an OrtValue instance if it contains a Tensor. + * + * \param[in] value The OrtValue instance to get the memory device from. + * \param[out] device The OrtMemoryDevice associated with the OrtValue instance. + * \return Status Success if OrtValue contains a Tensor. Otherwise, an error status is returned. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Value_GetMemoryDevice, _In_ const OrtValue* value, _Out_ const OrtMemoryDevice** device); + + /** \brief Compare two OrtMemoryDevice instances for equality. + * + * This is used to check if two memory devices are the same. + * Used to implement DataTransferImpl::CanCopy. + * + * \param[in] a The first OrtMemoryDevice instance to compare. + * \param[in] b The second OrtMemoryDevice instance to compare. 
+ * \return True if the two OrtMemoryDevice instances are equal, false otherwise. + * + * \since Version 1.23. + */ + ORT_API_T(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b); + + /** \brief Get the OrtMemoryInfoDeviceType value from an OrtMemoryDevice instance. + * + * \param[in] memory_device OrtMemoryDevice instance. + * \return The OrtMemoryInfoDeviceType value. + * + * \since Version 1.23. + */ + ORT_API_T(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemoryDevice* memory_device); + + /** \brief Get the OrtDeviceMemoryType value from an OrtMemoryDevice instance. + * + * \param[in] memory_device OrtMemoryDevice instance. + * \return The OrtDeviceMemoryType value. + * + * \since Version 1.23. + */ + ORT_API_T(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device); + + /** \brief Get the vendor ID from an OrtMemoryDevice instance. + * + * The vendor ID is used to identify the vendor of the device, and is typically set to the PCI vendor ID. + * + * If the device is not vendor specific (e.g. CPU memory) the vendor ID is set to 0. + * + * \param[in] memory_device OrtMemoryDevice instance. + * \return The vendor ID value. + * + * \since Version 1.23. + */ + ORT_API_T(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device); + + /** \brief Get the device ID from an OrtMemoryDevice instance. + * + * \param[in] memory_device OrtMemoryDevice instance. + * \return The device ID. + * + * \since Version 1.23. + */ + ORT_API_T(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device); }; +/** + * \brief The data layout type. + * + * EPs may specify a preferred data layout type. ORT's default layout type is OrtEpDataLayout_NCHW, or + * OrtEpDataLayout_Default. + * + * \since Version 1.23. + */ +typedef enum OrtEpDataLayout { + OrtEpDataLayout_NCHW = 0, + OrtEpDataLayout_NHWC, + + OrtEpDataLayout_Default = OrtEpDataLayout_NCHW, +} OrtEpDataLayout; + /** * \brief The OrtEp struct provides functions to implement for an execution provider. * \since Version 1.22. @@ -232,6 +432,101 @@ struct OrtEp { void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, _In_ size_t num_node_compute_infos); + + /** \brief Get the EP's preferred data layout. + * + * \note Implementation of this function is optional. + * If not implemented, ORT will assume that this EP prefers the data layout `OrtEpDataLayout::NCHW`. + * + * \param[in] this_ptr The OrtEp instance. + * \param[out] preferred_data_layout The EP's preferred data layout. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout); + + /** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout + * should be converted to `target_data_layout`. + * If the EP prefers a non-default data layout (see `GetPreferredDataLayout()`), this function will be called + * during layout transformation with `target_data_layout` set to the EP's preferred data layout. + * + * \note Implementation of this function is optional. + * If an EP prefers a non-default data layout, it may implement this to customize the specific op data layout + * preferences at a finer granularity. + * + * \param[in] this_ptr The OrtEp instance. + * \param[in] domain The op domain. 
An empty string means the ONNX domain. + * \param[in] op_type The op type. + * \param[in] target_data_layout The target data layout. + * \param[out] should_convert Whether the associated node's data layout should be converted to `target_data_layout`. + * If greater than 0, convert. + * If 0, don't convert. + * Otherwise, if less than 0, leave the decision to ORT. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert); + + /** \brief Set dynamic options on this EP. + * + * Dynamic options can be set by the user at any time after session creation with `OrtApi::SetEpDynamicOptions()`. + * + * \param[in] this_ptr The OrtEp instance. + * \param[in] option_keys The dynamic option keys. + * \param[in] option_values The dynamic option values. + * \param[in] num_options The number of dynamic options. + * + * \note Implementation of this function is optional. + * An EP should only implement this if it needs to handle any dynamic options. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr, + _In_reads_(num_options) const char* const* option_keys, + _In_reads_(num_options) const char* const* option_values, + _In_ size_t num_options); + + /** \brief Called by ORT to notify the EP of the start of a run. + * + * \param[in] this_ptr The OrtEp instance. + * \param[in] run_options The run options for this run. + * + * \note Implementation of this function is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options); + + /** \brief Called by ORT to notify the EP of the end of a run. + * + * \param[in] this_ptr The OrtEp instance. + * \param[in] run_options The run options for this run. + * \param[in] sync_stream Whether any associated stream should be synchronized during this call. + * Only applicable if there is such a stream. + * + * \note Implementation of this function is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options, + _In_ bool sync_stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -363,6 +658,46 @@ struct OrtEpFactory { * \since Version 1.22. */ void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); + + /** \brief Create an OrtAllocator for the given OrtMemoryInfo. + * + * This is used to create an allocator that an execution provider requires. The factory that creates the EP is + * responsible for providing the required allocators. + * The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using EpDevice_AddAllocatorInfo. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_info The OrtMemoryInfo to create the allocator for. + * \param[in] allocator_options Optional key-value pairs for allocator options, can be nullptr. + * \param[out] allocator The created OrtAllocator instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
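+ *
+ * Illustrative factory implementation sketch (`MyFactory` and its `CreateDeviceAllocator` helper
+ * are hypothetical types/functions in the EP library):
+ * \code{.cpp}
+ *   static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr,
+ *                                                      const OrtMemoryInfo* memory_info,
+ *                                                      const OrtKeyValuePairs* allocator_options,
+ *                                                      OrtAllocator** allocator) {
+ *     auto& factory = *static_cast<MyFactory*>(this_ptr);
+ *     // memory_info matches one of the entries registered via EpDevice_AddAllocatorInfo.
+ *     *allocator = factory.CreateDeviceAllocator(memory_info, allocator_options);
+ *     return nullptr;  // success
+ *   }
+ * \endcode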
+ */ + ORT_API2_STATUS(CreateAllocator, _In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _In_ const OrtKeyValuePairs* allocator_options, + _Outptr_ OrtAllocator** allocator); + + /** \brief Release an OrtAllocator created by the factory. + * + * \since Version 1.23. + */ + ORT_API_T(void, ReleaseAllocator, _In_ OrtEpFactory* this_ptr, _In_ OrtAllocator* allocator); + + /** \brief Create an OrtDataTransferImpl instance for the factory. + * + * This is used to create an IDataTransfer implementation that can be used to copy data between devices + * that the execution provider supports. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[out] data_transfer The created OrtDataTransferImpl instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateDataTransfer, _In_ OrtEpFactory* this_ptr, _Outptr_ OrtDataTransferImpl** data_transfer); }; #ifdef __cplusplus diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index ae0a06c9e749b..b3368479d2196 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -445,7 +445,7 @@ export const createSession = async ( : { name: nameString, isTensor: true, type: tensorDataTypeEnumToString(elementType), shape: shape! }, ); - if (!BUILD_DEFS.DISABLE_JSEP) { + if (!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) { if (enableGraphCapture && options?.preferredOutputLocation === undefined) { outputPreferredLocations.push('gpu-buffer'); continue; diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index a11691d22d8be..e99f380e6fe02 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -59,10 +59,10 @@ struct OP_QuickGelu : public CtxQuickGelu { #define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \ template void Impl_##name(cudaStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, size_t count); -#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \ - SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \ - SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \ - SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) \ +#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \ + SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \ + SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \ + SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) \ SPECIALIZED_UNARY_ACTIVATION_IMPL(name, BFloat16) #define UNARY_ACTIVATION_OP_NAME(name) \ diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 51311715d3b2a..216f101aad4be 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -358,42 +358,42 @@ Status LeanAttention( constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache( - device_prop, stream, - data.q, - data.k, // k_cache - data.v, // v_cache - nullptr, // new_k (we have appended new_k to k_cache) - nullptr, // new_v (we have appended new_v to k_cache) - data.output, - reinterpret_cast(data.softmax_lse), - nullptr, // seqlens_k - nullptr, // cos_cache - nullptr, // sin_cache - nullptr, // block_table - parameters.batch_size, - parameters.num_heads, - parameters.num_heads, // num_heads_k - parameters.head_size, - parameters.sequence_length, // seqlen_q - parameters.total_sequence_length, // 
seqlen_k - 0, // seqlen_k_new - 0, // rotary_dim - scale, // softmax_scale - parameters.is_unidirectional, - is_bf16, - false, // past_bsnh - data.num_splits, - data.grid_dim_z, - data.max_tiles_per_tb, - data.high_load_tbs, - data.tiles_per_head, - reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), - data.lean_sync_flag, - -1, // local_window_size - false, // is_rotary_interleaved - false // is_packed_qkv - )); + device_prop, stream, + data.q, + data.k, // k_cache + data.v, // v_cache + nullptr, // new_k (we have appended new_k to k_cache) + nullptr, // new_v (we have appended new_v to k_cache) + data.output, + reinterpret_cast(data.softmax_lse), + nullptr, // seqlens_k + nullptr, // cos_cache + nullptr, // sin_cache + nullptr, // block_table + parameters.batch_size, + parameters.num_heads, + parameters.num_heads, // num_heads_k + parameters.head_size, + parameters.sequence_length, // seqlen_q + parameters.total_sequence_length, // seqlen_k + 0, // seqlen_k_new + 0, // rotary_dim + scale, // softmax_scale + parameters.is_unidirectional, + is_bf16, + false, // past_bsnh + data.num_splits, + data.grid_dim_z, + data.max_tiles_per_tb, + data.high_load_tbs, + data.tiles_per_head, + reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), + data.lean_sync_flag, + -1, // local_window_size + false, // is_rotary_interleaved + false // is_packed_qkv + )); return Status::OK(); } @@ -414,8 +414,6 @@ Status LeanAttention( } #endif - - template Status CudnnFlashAttention( cudnnHandle_t cudnn_handle, @@ -439,21 +437,21 @@ Status CudnnFlashAttention( data.k, data.v, attention_bias, - nullptr, // (optional) mask_sequence_lengths_q - mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv + nullptr, // (optional) mask_sequence_lengths_q + mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv parameters.batch_size, - parameters.num_heads, // num_heads_q, - parameters.num_heads, // num_heads_kv, - parameters.head_size, // head_size_qk - parameters.v_head_size, // head_size_v - parameters.sequence_length, // sequence_length_q - parameters.total_sequence_length, // sequence_length_kv - scale, // scaling factor applied prior softmax - parameters.is_unidirectional, // causal - is_bf16, // True if bfloat16, otherwise float16 - parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not - parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not - 0, // sliding window length. 0 means no sliding window. + parameters.num_heads, // num_heads_q, + parameters.num_heads, // num_heads_kv, + parameters.head_size, // head_size_qk + parameters.v_head_size, // head_size_v + parameters.sequence_length, // sequence_length_q + parameters.total_sequence_length, // sequence_length_kv + scale, // scaling factor applied prior softmax + parameters.is_unidirectional, // causal + is_bf16, // True if bfloat16, otherwise float16 + parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not + parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not + 0, // sliding window length. 0 means no sliding window. 
data.qkv_format, cudnn_handle, ort_stream, @@ -540,10 +538,9 @@ Status EfficientAttention( template Status LaunchDecoderMaskedMultiHeadAttention( - const DecoderMaskedMultiHeadAttentionParameters& parameters, - cudaStream_t stream, - const int head_size) { - + const DecoderMaskedMultiHeadAttentionParameters& parameters, + cudaStream_t stream, + const int head_size) { DUMP_STRING_INIT(); DUMP_STRING("DMMHA parameters..."); DUMP_STRING("is_mha = ", (parameters.is_mha == true)); @@ -763,7 +760,7 @@ Status UnfusedAttention( if (nullptr != data.output_qk) { int64_t qk_size = (int64_t)batch_size * num_heads * sequence_length * total_sequence_length; ORT_RETURN_IF_ERROR( - (CopyQK(stream, static_cast(qk_size), data.scratch, reinterpret_cast(data.output_qk)))); + (CopyQK(stream, static_cast(qk_size), data.scratch, reinterpret_cast(data.output_qk)))); } ORT_RETURN_IF_ERROR( ComputeSoftmax( @@ -802,7 +799,7 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. - if (nullptr != data.present) { // Attention op + if (nullptr != data.present) { // Attention op assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); @@ -811,12 +808,10 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, 2, data.past, data.k, data.present)); - - // Update pointers to present_k and present_v. data.k = data.present; data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } else { // MultiHeadAttention op + } else { // MultiHeadAttention op if (nullptr != data.present_key) { ORT_ENFORCE(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); @@ -826,16 +821,16 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, data.k, data.present_key)); + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, data.k, data.present_key)); ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, data.v, data.present_value)); + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); // Update pointers to present_k and present_v. 
data.k = data.present_key; data.v = data.present_value; - } else { // nullptr == data.past_key && nullptr != data.present_key + } else { // nullptr == data.past_key && nullptr != data.present_key if (data.k != data.present_key) { int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; cudaMemcpyAsync(data.present_key, data.k, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); @@ -889,7 +884,7 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i return Status::OK(); } - if (combined_key_value) { // Attention op + if (combined_key_value) { // Attention op assert(data.gemm_buffer != nullptr); if (data.present != data.past) { @@ -924,9 +919,9 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i constexpr bool is_past_kv_bnsh_format = true; constexpr bool is_new_kv_bnsh_format = true; ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - batch_size, num_heads, qk_head_size, parameters.max_sequence_length, - data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value, - is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block)); + batch_size, num_heads, qk_head_size, parameters.max_sequence_length, + data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value, + is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block)); data.k = data.present_key; data.v = data.present_value; @@ -981,13 +976,13 @@ Status QkvToContext( if (!parameters.past_present_share_buffer) { ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, - stream, max_threads_per_block, data)); + sequence_length, total_sequence_length, + stream, max_threads_per_block, data)); } else { // past_present_share_buffer ORT_RETURN_IF_ERROR(PastPresentBufferShare(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, fused_runner, - parameters, data, stream, max_threads_per_block)); + sequence_length, fused_runner, + parameters, data, stream, max_threads_per_block)); } // Q, K and V are ready now @@ -1078,24 +1073,24 @@ template Status QkvToContext( AttentionData& data); template Status LaunchDecoderMaskedMultiHeadAttention( - const DecoderMaskedMultiHeadAttentionParameters& parameters, - cudaStream_t stream, - const int head_size); + const DecoderMaskedMultiHeadAttentionParameters& parameters, + cudaStream_t stream, + const int head_size); template Status LaunchDecoderMaskedMultiHeadAttention( - const DecoderMaskedMultiHeadAttentionParameters& parameters, - cudaStream_t stream, - const int head_size); + const DecoderMaskedMultiHeadAttentionParameters& parameters, + cudaStream_t stream, + const int head_size); template Status LaunchDecoderMaskedMultiHeadAttention( - const DecoderMaskedMultiHeadAttentionParameters& parameters, - cudaStream_t stream, - const int head_size); + const DecoderMaskedMultiHeadAttentionParameters& parameters, + cudaStream_t stream, + const int head_size); template Status LaunchDecoderMaskedMultiHeadAttention( - const DecoderMaskedMultiHeadAttentionParameters& parameters, - cudaStream_t stream, - const int head_size); + const DecoderMaskedMultiHeadAttentionParameters& parameters, + cudaStream_t stream, + const int head_size); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 
1753929d60617..121ddcf779485 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -197,7 +197,6 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, return CUDA_CALL(cudaGetLastError()); } - #ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP // ---------------------------------------------------------------------------------- diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu index 9e00ca713a448..acc05c7053b5f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu @@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation. Licensed under the MIT License. */ /* -Kernel implementation for Gamma rotary embeddings. +Kernel implementation for Gamma rotary embeddings. This implementation below subgraph (emb) / \ @@ -16,7 +16,7 @@ This implementation below subgraph \/ \/ \/ \/ Mul Mul Mul Mul \ / \ / - Add Add + Add Add | | (output1) (output2) */ @@ -36,27 +36,26 @@ constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock; template __global__ void GemmaRotaryEmb( - T* output1, - T* output2, - const U* emb, - const T* q, - const T* q_rot, - const T* k, - const T* k_rot, - const int batch_size, - const int num_heads, - const int seq_len, - const int dim) { - - const int qk_idx = blockIdx.x * blockDim.x + threadIdx.x; - // index [i, j, k, l] -> [i, k, l] - const int emb_idx = qk_idx / (num_heads * seq_len * dim) * (seq_len * dim) + qk_idx % (seq_len * dim); - if (qk_idx < batch_size * num_heads * seq_len * dim) { - T sin_val = static_cast(sin(emb[emb_idx])); - T cos_val = static_cast(cos(emb[emb_idx])); - output1[qk_idx] = q[qk_idx] * cos_val + q_rot[qk_idx] * sin_val; - output2[qk_idx] = k[qk_idx] * cos_val + k_rot[qk_idx] * sin_val; - } + T* output1, + T* output2, + const U* emb, + const T* q, + const T* q_rot, + const T* k, + const T* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim) { + const int qk_idx = blockIdx.x * blockDim.x + threadIdx.x; + // index [i, j, k, l] -> [i, k, l] + const int emb_idx = qk_idx / (num_heads * seq_len * dim) * (seq_len * dim) + qk_idx % (seq_len * dim); + if (qk_idx < batch_size * num_heads * seq_len * dim) { + T sin_val = static_cast(sin(emb[emb_idx])); + T cos_val = static_cast(cos(emb[emb_idx])); + output1[qk_idx] = q[qk_idx] * cos_val + q_rot[qk_idx] * sin_val; + output2[qk_idx] = k[qk_idx] * cos_val + k_rot[qk_idx] * sin_val; + } } template @@ -72,15 +71,13 @@ Status LaunchGemmaRotaryEmbeddingKernel( const int batch_size, const int num_heads, const int seq_len, - const int dim - ) { + const int dim) { int blocksPerGrid = static_cast(ceil(float(batch_size * num_heads * seq_len * dim) / kThreadsPerBlock)); GemmaRotaryEmb<<>>( - output1, output2, - emb, q, q_rot, k, k_rot, - batch_size, num_heads, seq_len, dim - ); + output1, output2, + emb, q, q_rot, k, k_rot, + batch_size, num_heads, seq_len, dim); return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 2d1b49033003d..bb450e476d5ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -109,7 +109,7 @@ Status 
LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, const int max_threads_per_block) { const int max_sequence_length = parameters.seqlen_present_kv_cache; const int* seqlens_k = (parameters.is_first_prompt && !parameters.is_subsequent_prompt) ? nullptr - : reinterpret_cast(data.seqlens_k); + : reinterpret_cast(data.seqlens_k); assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu index cfcacbabb3cb9..42de42389145f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu @@ -8,8 +8,8 @@ namespace onnxruntime { namespace lean { -template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash +} // namespace lean } // namespace onnxruntime #endif diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu index 44c870f6ab35b..7b0d3ce3b93db 100644 --- a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu @@ -8,8 +8,8 @@ namespace onnxruntime { namespace lean { -template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream); -} // namespace flash +} // namespace lean } // namespace onnxruntime #endif diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu index 7ecdf51bdde11..a1eadfcedc6c6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu @@ -87,11 +87,11 @@ Status LaunchUnpackCumulative(const T* input, T* output, const int token_count, } template -__global__ void RotaryEmbeddingTNH(T* output, // TxNxH - const T* input, // TxNxH - const T* cos_cache, // Mx(H/2) - const T* sin_cache, // Mx(H/2) - const int32_t* past_seqlens, // B +__global__ void RotaryEmbeddingTNH(T* output, // TxNxH + const T* input, // TxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int32_t* past_seqlens, // B const int32_t* cumulative_seqlens_q, // B+1 const int head_size, const int rotary_embedding_dim, @@ -110,7 +110,7 @@ __global__ void RotaryEmbeddingTNH(T* output, // TxNxH return; } - const int t = cumulative_seqlens_q[b] + s; // t is the index of the token in the unpadded input/output + const int t = cumulative_seqlens_q[b] + s; // t is the index of the token in the unpadded input/output const T* input_data = input + t * in_strides.x + n * in_strides.y; T* output_data = output + t * out_strides.x + n * out_strides.y; @@ -213,7 +213,7 @@ __global__ void ReshapeAndCache(const T* __restrict__ key, const T* __restrict__ } const int token_offset = token_id - cumulative_seqlens_q[batch_id]; const int past_length = past_seqlens[batch_id]; - const int block_id = block_table[batch_id * max_num_blocks_per_seq + (past_length + token_offset) / block_size]; + const int block_id = block_table[batch_id * max_num_blocks_per_seq + (past_length + 
token_offset) / block_size]; const int block_offset = (past_length + token_offset) % block_size; const int key_id = token_id * key_stride + hidden_offset; diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index a1dcab0a6bf89..88ab3b5831afe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -212,7 +212,7 @@ void LaunchSkipLayerNormKernel( #define CASE_NEXT_SIZE(next_size_value) \ case next_size_value: { \ static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ - if constexpr (next_size_value >= 320) { \ + if constexpr (next_size_value >= 320) { \ if (can_unroll_vec8) { \ constexpr int block_size = next_size_value / 8; \ LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ diff --git a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu index 9de3d48417b34..9678159532eb1 100644 --- a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu +++ b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu @@ -36,25 +36,25 @@ using namespace onnxruntime::cuda; //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ void st_flag_release(uint32_t const &flag, uint32_t *flag_addr) { +static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) { #if __CUDA_ARCH__ >= 700 - asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); + asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); #else - __threadfence_system(); - asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); + __threadfence_system(); + asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr) { - uint32_t flag; +static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) { + uint32_t flag; #if __CUDA_ARCH__ >= 700 - asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); + asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); #else - asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); + asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); #endif - return flag; + return flag; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -62,565 +62,569 @@ static inline __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr) { // Type Converter that packs data format to 128 bits data type // using PackedFloat = union { - int4 packed; - float unpacked[4]; + int4 packed; + float unpacked[4]; }; using PackedHalf = union { - int4 packed; - half2 unpacked[4]; + int4 packed; + half2 unpacked[4]; }; -template struct PackedOn16Bytes {}; +template +struct PackedOn16Bytes {}; -template <> struct PackedOn16Bytes { - using Type = PackedFloat; +template <> +struct PackedOn16Bytes { + using Type = PackedFloat; }; -template <> struct PackedOn16Bytes { - using Type = PackedHalf; +template <> +struct PackedOn16Bytes { + using Type = PackedHalf; }; // add two 128b data -template inline __device__ int4 add128b(T &a, T &b) { - T c; - c.unpacked[0] = a.unpacked[0] + 
b.unpacked[0]; - c.unpacked[1] = a.unpacked[1] + b.unpacked[1]; - c.unpacked[2] = a.unpacked[2] + b.unpacked[2]; - c.unpacked[3] = a.unpacked[3] + b.unpacked[3]; - return c.packed; +template +inline __device__ int4 add128b(T& a, T& b) { + T c; + c.unpacked[0] = a.unpacked[0] + b.unpacked[0]; + c.unpacked[1] = a.unpacked[1] + b.unpacked[1]; + c.unpacked[2] = a.unpacked[2] + b.unpacked[2]; + c.unpacked[3] = a.unpacked[3] + b.unpacked[3]; + return c.packed; } -__inline__ __device__ void multi_gpu_barrier(uint32_t **signals, uint32_t const flag, size_t const local_rank, +__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, size_t const world_size, int const tidx, int const bidx) { - // After this function, at least one block in each GPU has reached the barrier - if (tidx < world_size) { - // we can think of signals having the shape [world_size, world_size] - // Dimension 0 is the "listening" dimension, dimension 2 is "emitting" dimension - - // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers - if (bidx == 0) { - signals[tidx][local_rank] = flag; - } + // After this function, at least one block in each GPU has reached the barrier + if (tidx < world_size) { + // we can think of signals having the shape [world_size, world_size] + // Dimension 0 is the "listening" dimension, dimension 2 is "emitting" dimension + + // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers + if (bidx == 0) { + signals[tidx][local_rank] = flag; + } - // All blocks check that corresponding block 0 on other GPUs have set the flag - // No deadlock because block #0 is always the first block started - uint32_t volatile *my_signals = signals[local_rank]; - while (my_signals[tidx] != flag) { - } + // All blocks check that corresponding block 0 on other GPUs have set the flag + // No deadlock because block #0 is always the first block started + uint32_t volatile* my_signals = signals[local_rank]; + while (my_signals[tidx] != flag) { } + } - __syncthreads(); + __syncthreads(); } -__inline__ __device__ void block_barrier(uint32_t **signals, uint32_t const flag, size_t const local_rank, +__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, size_t const world_size, int const tidx, int const bidx) { - // After this function, the block of id == bidx of each GPU has reached the barrier - if (tidx < world_size) { - // we can think of signals having the shape [world_size, num_blocks, world_size] - // (+ an offset on dim 1 to account for flags used in multi_gpu_barrier) - // Dimension 0 is the "listening" dimension, dimension 2 is "emitting" dimension - - // Block broadcast its flag (local_rank on emitting dimension) to all receivers - uint32_t flag_block_offset = world_size + bidx * world_size; - st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); - - // Blocks check that corresponding blocks on other GPUs have also set the flag - uint32_t *peer_barrier_d = signals[local_rank] + flag_block_offset + tidx; - while (ld_flag_acquire(peer_barrier_d) != flag) { - } + // After this function, the block of id == bidx of each GPU has reached the barrier + if (tidx < world_size) { + // we can think of signals having the shape [world_size, num_blocks, world_size] + // (+ an offset on dim 1 to account for flags used in multi_gpu_barrier) + // Dimension 0 is the "listening" dimension, dimension 2 is "emitting" dimension + + // Block broadcast its flag (local_rank 
on emitting dimension) to all receivers + uint32_t flag_block_offset = world_size + bidx * world_size; + st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); + + // Blocks check that corresponding blocks on other GPUs have also set the flag + uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx; + while (ld_flag_acquire(peer_barrier_d) != flag) { } + } - __syncthreads(); + __syncthreads(); } template static __global__ void oneShotAllReduceKernel(AllReduceParams params) { - // Suppose that two GPUs participate in the AR exchange, and we start four blocks. - // The message is partitioned into chunks as detailed below: - // message - // |-------------------| - // GPU 0 | B0 | B1 | B2 | B3 | - // GPU 1 | B0 | B1 | B2 | B3 | - // - // Here the step-by-step behavior of one block: - // 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer - // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier) - // 3. B0 on GPU 0 pull and sum the chunk from GPU 1, writes the result to local_output - // - // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2. - // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready - // - // With PUSH_MODE, we consider that the shared buffer is of size: - // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size] - // - // Here the step-by-step behavior of one block: - // 1. B0 push the chunk is it responsible for into all other GPUs: - // params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice] - // 2. block sync so the block is shared by other GPUs - // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice] - - int const bidx = blockIdx.x; - int const tidx = threadIdx.x; - - // The number of elements packed into one for comms - static constexpr int PACKED_ELTS = 16 / sizeof(T); - using PackedStruct = typename PackedOn16Bytes::Type; - - [[maybe_unused]] T const *local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); - [[maybe_unused]] T *local_shared_buffer = reinterpret_cast(params.peer_comm_buffer_ptrs[params.local_rank]); - [[maybe_unused]] T *local_output_buffer = reinterpret_cast(params.local_output_buffer_ptr); - - // Start and end offsets of the thread - size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS; - size_t const chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_total); - - T *buffers[RANKS_PER_NODE]; + // Suppose that two GPUs participate in the AR exchange, and we start four blocks. + // The message is partitioned into chunks as detailed below: + // message + // |-------------------| + // GPU 0 | B0 | B1 | B2 | B3 | + // GPU 1 | B0 | B1 | B2 | B3 | + // + // Here the step-by-step behavior of one block: + // 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer + // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier) + // 3. B0 on GPU 0 pull and sum the chunk from GPU 1, writes the result to local_output + // + // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2. + // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready + // + // With PUSH_MODE, we consider that the shared buffer is of size: + // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size] + // + // Here the step-by-step behavior of one block: + // 1. 
B0 push the chunk is it responsible for into all other GPUs: + // params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice] + // 2. block sync so the block is shared by other GPUs + // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice] + + int const bidx = blockIdx.x; + int const tidx = threadIdx.x; + + // The number of elements packed into one for comms + static constexpr int PACKED_ELTS = 16 / sizeof(T); + using PackedStruct = typename PackedOn16Bytes::Type; + + [[maybe_unused]] T const* local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); + [[maybe_unused]] T* local_shared_buffer = reinterpret_cast(params.peer_comm_buffer_ptrs[params.local_rank]); + [[maybe_unused]] T* local_output_buffer = reinterpret_cast(params.local_output_buffer_ptr); + + // Start and end offsets of the thread + size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS; + size_t const chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_total); + + T* buffers[RANKS_PER_NODE]; #pragma unroll - for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - // buffers[0] is always the local buffers. Helps load balancing reads. - int rank = (params.local_rank + ii) % RANKS_PER_NODE; - buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); - } - - if constexpr (PUSH_MODE || COPY_INPUT) { - // Copy from local buffer to shareable buffer - for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * PACKED_ELTS) { - if constexpr (PUSH_MODE) { -#pragma unroll - for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - *reinterpret_cast(&buffers[ii][params.local_rank * params.elts_total + iter_offset]) = - *reinterpret_cast(&local_input_buffer[iter_offset]); - } - } else { - *reinterpret_cast(&local_shared_buffer[iter_offset]) = - *reinterpret_cast(&local_input_buffer[iter_offset]); - } - } - // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer - block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); - } else { - // In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed - multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, - bidx); - } - - // Each block accumulates the values from the different GPUs on the same node. + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + // buffers[0] is always the local buffers. Helps load balancing reads. + int rank = (params.local_rank + ii) % RANKS_PER_NODE; + buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); + } + + if constexpr (PUSH_MODE || COPY_INPUT) { + // Copy from local buffer to shareable buffer for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * PACKED_ELTS) { - // Iterate over the different ranks/devices on the node to load the values. - PackedStruct vals[RANKS_PER_NODE]; + if constexpr (PUSH_MODE) { #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - if constexpr (PUSH_MODE) { - vals[ii].packed = - *reinterpret_cast(&buffers[params.local_rank][ii * params.elts_total + iter_offset]); - } else { - vals[ii].packed = *reinterpret_cast(&buffers[ii][iter_offset]); - } - } - - // Sum the values from the different ranks. - PackedStruct sums; - sums.packed = {0, 0, 0, 0}; -#pragma unroll - for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { - // Always reduce from rank 0 to ensure stable reduce order. 
- int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE; - sums.packed = add128b(sums, vals[ii]); + *reinterpret_cast(&buffers[ii][params.local_rank * params.elts_total + iter_offset]) = + *reinterpret_cast(&local_input_buffer[iter_offset]); } - - // Store to the destination buffer. - *reinterpret_cast(&local_output_buffer[iter_offset]) = sums.packed; + } else { + *reinterpret_cast(&local_shared_buffer[iter_offset]) = + *reinterpret_cast(&local_input_buffer[iter_offset]); + } } -} - -template -static __global__ void twoShotAllReduceKernel(AllReduceParams params) { - // Suppose that two GPUs participate in the AR exchange, and we start two blocks. - // The message is partitioned into chunks as detailed below: - // message - // |-------------------| - // |--GPU 0--|--GPU 1--| (GPU responsibility parts) - // GPU 0 | B0 | B1 | B0 | B1 | - // GPU 1 | B0 | B1 | B0 | B1 | - // - // Here the step-by-step behavior of one block: - // 1. B0 copies all chunks is it responsible for, from local_input to shareable buffer - // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0) - // 3. B0 on GPU 0 gather and sum the B0 chunks from GPU 1, that are in the GPU 0 responsibility - // part (the first half of the message, see GPU responsibility row above) - // 3bis. Likewise, B0 on GPU 1 copies and sum the chunks for GPU 0, - // where GPU 1 is responsible: the second half of the message. - // 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1) - // 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU. - // For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1. - // - // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2. - // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready - // to be read. - // - // Note that compared to one-shot, one block (CTA) writes multiple input chunks and write multiple output chunks. - // However, it's only responsible for the summation of a single chunk. - // - // With PUSH_MODE, we consider that the shared buffer is of size: - // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size] - // - // Here the step-by-step behavior of one block: - // 1. B0 push the chunks is it responsible for into the corresponding GPUs: - // params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice] - // 2. block sync so the blocks have been shared by other GPUs - // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice] - // 4. block barrier (corresponding blocks have finished reduction) - // 5. 
pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is - // written at index 0 of 2nd dim) - - int const bidx = blockIdx.x; - int const tidx = threadIdx.x; - - // The number of elements packed into one for comms - static constexpr int PACKED_ELTS = 16 / sizeof(T); - using PackedType = typename PackedOn16Bytes::Type; - - [[maybe_unused]] T const *local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); - [[maybe_unused]] T *local_shared_buffer = reinterpret_cast(params.peer_comm_buffer_ptrs[params.local_rank]); - [[maybe_unused]] T *local_output_buffer = reinterpret_cast(params.local_output_buffer_ptr); - - size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS; - size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank); - - T *buffers[RANKS_PER_NODE]; - int ranks[RANKS_PER_NODE]; + // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer + block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); + } else { + // In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed + multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, + bidx); + } + + // Each block accumulates the values from the different GPUs on the same node. + for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * PACKED_ELTS) { + // Iterate over the different ranks/devices on the node to load the values. + PackedStruct vals[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - // A mapping of the ranks to scatter reads as much as possible - int rank = (params.local_rank + ii) % RANKS_PER_NODE; - ranks[ii] = rank; - buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); + if constexpr (PUSH_MODE) { + vals[ii].packed = + *reinterpret_cast(&buffers[params.local_rank][ii * params.elts_total + iter_offset]); + } else { + vals[ii].packed = *reinterpret_cast(&buffers[ii][iter_offset]); + } } - if constexpr (PUSH_MODE || COPY_INPUT) { - // Copy all blocks from local buffer to shareable buffer - for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { + // Sum the values from the different ranks. + PackedStruct sums; + sums.packed = {0, 0, 0, 0}; #pragma unroll - for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - size_t offset_rank = ii * params.elts_per_rank + local_offset; - if (offset_rank >= params.elts_total) { - continue; - } - - if constexpr (PUSH_MODE) { - *reinterpret_cast(&buffers[ii][params.local_rank * params.elts_per_rank + local_offset]) = - *reinterpret_cast(&local_input_buffer[offset_rank]); - } else { - *reinterpret_cast(&local_shared_buffer[offset_rank]) = - *reinterpret_cast(&local_input_buffer[offset_rank]); - } - } - } - block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); - } else { - // In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed - multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, - bidx); + for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { + // Always reduce from rank 0 to ensure stable reduce order. 
+ int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE; + sums.packed = add128b(sums, vals[ii]); } - // Each block accumulates the values from the different GPUs on the same node. - for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { - size_t const responsible_block_offset = local_offset + params.rank_offset; + // Store to the destination buffer. + *reinterpret_cast(&local_output_buffer[iter_offset]) = sums.packed; + } +} - // Iterate over the different ranks/devices on the node to load the values. - PackedType vals[RANKS_PER_NODE]; +template +static __global__ void twoShotAllReduceKernel(AllReduceParams params) { + // Suppose that two GPUs participate in the AR exchange, and we start two blocks. + // The message is partitioned into chunks as detailed below: + // message + // |-------------------| + // |--GPU 0--|--GPU 1--| (GPU responsibility parts) + // GPU 0 | B0 | B1 | B0 | B1 | + // GPU 1 | B0 | B1 | B0 | B1 | + // + // Here the step-by-step behavior of one block: + // 1. B0 copies all chunks is it responsible for, from local_input to shareable buffer + // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0) + // 3. B0 on GPU 0 gather and sum the B0 chunks from GPU 1, that are in the GPU 0 responsibility + // part (the first half of the message, see GPU responsibility row above) + // 3bis. Likewise, B0 on GPU 1 copies and sum the chunks for GPU 0, + // where GPU 1 is responsible: the second half of the message. + // 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1) + // 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU. + // For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1. + // + // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2. + // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready + // to be read. + // + // Note that compared to one-shot, one block (CTA) writes multiple input chunks and write multiple output chunks. + // However, it's only responsible for the summation of a single chunk. + // + // With PUSH_MODE, we consider that the shared buffer is of size: + // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size] + // + // Here the step-by-step behavior of one block: + // 1. B0 push the chunks is it responsible for into the corresponding GPUs: + // params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice] + // 2. block sync so the blocks have been shared by other GPUs + // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice] + // 4. block barrier (corresponding blocks have finished reduction) + // 5. 
pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is + // written at index 0 of 2nd dim) + + int const bidx = blockIdx.x; + int const tidx = threadIdx.x; + + // The number of elements packed into one for comms + static constexpr int PACKED_ELTS = 16 / sizeof(T); + using PackedType = typename PackedOn16Bytes::Type; + + [[maybe_unused]] T const* local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); + [[maybe_unused]] T* local_shared_buffer = reinterpret_cast(params.peer_comm_buffer_ptrs[params.local_rank]); + [[maybe_unused]] T* local_output_buffer = reinterpret_cast(params.local_output_buffer_ptr); + + size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS; + size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank); + + T* buffers[RANKS_PER_NODE]; + int ranks[RANKS_PER_NODE]; #pragma unroll - for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - if constexpr (PUSH_MODE) { - vals[ii].packed = - *reinterpret_cast(&local_shared_buffer[ii * params.elts_per_rank + local_offset]); - } else { - vals[ii].packed = *reinterpret_cast(&buffers[ii][responsible_block_offset]); - } - } - - // Sum the values from the different ranks. - PackedType sums; - sums.packed = {0, 0, 0, 0}; + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + // A mapping of the ranks to scatter reads as much as possible + int rank = (params.local_rank + ii) % RANKS_PER_NODE; + ranks[ii] = rank; + buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); + } + + if constexpr (PUSH_MODE || COPY_INPUT) { + // Copy all blocks from local buffer to shareable buffer + for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { #pragma unroll - for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { - // Always reduce from rank 0 to ensure stable reduce order. - int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE; - sums.packed = add128b(sums, vals[ii]); + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + size_t offset_rank = ii * params.elts_per_rank + local_offset; + if (offset_rank >= params.elts_total) { + continue; } - // Store to the local buffer. if constexpr (PUSH_MODE) { - *reinterpret_cast(&local_shared_buffer[local_offset]) = sums.packed; + *reinterpret_cast(&buffers[ii][params.local_rank * params.elts_per_rank + local_offset]) = + *reinterpret_cast(&local_input_buffer[offset_rank]); } else { - *reinterpret_cast(&local_shared_buffer[responsible_block_offset]) = sums.packed; + *reinterpret_cast(&local_shared_buffer[offset_rank]) = + *reinterpret_cast(&local_input_buffer[offset_rank]); } + } + } + block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); + } else { + // In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed + multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, + bidx); + } + + // Each block accumulates the values from the different GPUs on the same node. + for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { + size_t const responsible_block_offset = local_offset + params.rank_offset; + + // Iterate over the different ranks/devices on the node to load the values. 
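// Illustrative host-side sketch: why the index "(rank + RANKS_PER_NODE - local_rank) % RANKS_PER_NODE"
// used for the reduction above gives every GPU the same summation order (device 0, 1, 2, ...),
// which keeps the floating-point result bitwise identical across ranks. RANKS_PER_NODE = 4 is an
// assumed example value; plain C++ for clarity, not kernel code.
#include <cstdio>

int main() {
  constexpr int RANKS_PER_NODE = 4;
  for (int local_rank = 0; local_rank < RANKS_PER_NODE; ++local_rank) {
    std::printf("local_rank %d reduces devices:", local_rank);
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // vals[ii] was loaded from device (local_rank + ii) % RANKS_PER_NODE (the scatter-read mapping).
      int ii = (rank + RANKS_PER_NODE - local_rank) % RANKS_PER_NODE;
      int device = (local_rank + ii) % RANKS_PER_NODE;  // simplifies to 'rank' for every local_rank
      std::printf(" %d", device);
    }
    std::printf("\n");  // always prints " 0 1 2 3"
  }
  return 0;
}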
+ PackedType vals[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + if constexpr (PUSH_MODE) { + vals[ii].packed = + *reinterpret_cast(&local_shared_buffer[ii * params.elts_per_rank + local_offset]); + } else { + vals[ii].packed = *reinterpret_cast(&buffers[ii][responsible_block_offset]); + } } - block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); + // Sum the values from the different ranks. + PackedType sums; + sums.packed = {0, 0, 0, 0}; +#pragma unroll + for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { + // Always reduce from rank 0 to ensure stable reduce order. + int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE; + sums.packed = add128b(sums, vals[ii]); + } - // Gather all needed elts from other intra-node ranks - for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { + // Store to the local buffer. + if constexpr (PUSH_MODE) { + *reinterpret_cast(&local_shared_buffer[local_offset]) = sums.packed; + } else { + *reinterpret_cast(&local_shared_buffer[responsible_block_offset]) = sums.packed; + } + } + + block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); + + // Gather all needed elts from other intra-node ranks + for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { #pragma unroll - for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - // use round-robin gathering from other ranks - size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset; - if (offset_rank >= params.elts_total) { - continue; - } - - if constexpr (PUSH_MODE) { - *reinterpret_cast(&local_output_buffer[offset_rank]) = - *reinterpret_cast(&buffers[ii][local_offset]); - } else { - *reinterpret_cast(&local_output_buffer[offset_rank]) = - *reinterpret_cast(&buffers[ii][offset_rank]); - } - } + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + // use round-robin gathering from other ranks + size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset; + if (offset_rank >= params.elts_total) { + continue; + } + + if constexpr (PUSH_MODE) { + *reinterpret_cast(&local_output_buffer[offset_rank]) = + *reinterpret_cast(&buffers[ii][local_offset]); + } else { + *reinterpret_cast(&local_output_buffer[offset_rank]) = + *reinterpret_cast(&buffers[ii][offset_rank]); + } } + } } bool ConfigurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t world_size, onnxruntime::MLDataType type) { - size_t elts_per_thread = 16 / type->Size(); - int const msg_align = (algo == AllReduceStrategyType::TWOSHOT) ? world_size * elts_per_thread : elts_per_thread; - bool supported_algo = (algo == AllReduceStrategyType::ONESHOT || algo == AllReduceStrategyType::TWOSHOT); - return supported_algo && (msg_size % msg_align == 0); + size_t elts_per_thread = 16 / type->Size(); + int const msg_align = (algo == AllReduceStrategyType::TWOSHOT) ? 
world_size * elts_per_thread : elts_per_thread; + bool supported_algo = (algo == AllReduceStrategyType::ONESHOT || algo == AllReduceStrategyType::TWOSHOT); + return supported_algo && (msg_size % msg_align == 0); } -std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams ¶m, size_t elts_per_thread) { - int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; +std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& param, size_t elts_per_thread) { + int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; - switch (algo) { + switch (algo) { case AllReduceStrategyType::ONESHOT: { - ORT_ENFORCE(param.elts_total % elts_per_thread == 0); - size_t const total_threads = roundUp(param.elts_total / elts_per_thread, WARP_SIZE); - threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); - blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), - divUp(total_threads, static_cast(threads_per_block))); - param.elts_per_block = roundUp(divUp(param.elts_total, static_cast(blocks_per_grid)), elts_per_thread); - break; + ORT_ENFORCE(param.elts_total % elts_per_thread == 0); + size_t const total_threads = roundUp(param.elts_total / elts_per_thread, WARP_SIZE); + threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); + blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), + divUp(total_threads, static_cast(threads_per_block))); + param.elts_per_block = roundUp(divUp(param.elts_total, static_cast(blocks_per_grid)), elts_per_thread); + break; } case AllReduceStrategyType::TWOSHOT: { - ORT_ENFORCE(param.elts_total % (elts_per_thread * param.ranks_per_node) == 0); - size_t const total_threads = roundUp(param.elts_total / (elts_per_thread * param.ranks_per_node), WARP_SIZE); + ORT_ENFORCE(param.elts_total % (elts_per_thread * param.ranks_per_node) == 0); + size_t const total_threads = roundUp(param.elts_total / (elts_per_thread * param.ranks_per_node), WARP_SIZE); - /* - threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); - blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block)); - */ + /* + threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); + blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block)); + */ - while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) { - blocks_per_grid += 1; - } + while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) { + blocks_per_grid += 1; + } - threads_per_block = total_threads / blocks_per_grid; + threads_per_block = total_threads / blocks_per_grid; - // NOTE: need to adjust here - if (static_cast(blocks_per_grid) > MAX_ALL_REDUCE_BLOCKS) { - size_t iter_factor = 1; - while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) { - iter_factor += 1; - } - blocks_per_grid /= iter_factor; + // NOTE: need to adjust here + if (static_cast(blocks_per_grid) > MAX_ALL_REDUCE_BLOCKS) { + size_t iter_factor = 1; + while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) { + iter_factor += 1; } - param.elts_per_rank = param.elts_total / param.ranks_per_node; - param.rank_offset = param.local_rank * param.elts_per_rank; - param.elts_per_block = - roundUp(divUp(param.elts_per_rank, static_cast(blocks_per_grid)), elts_per_thread); - break; + blocks_per_grid /= iter_factor; + } + param.elts_per_rank = param.elts_total / 
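// Minimal host-side restatement of the alignment rule enforced by ConfigurationSupported above:
// each thread moves 16 bytes, so the element count must be a multiple of 16 / sizeof(T), and the
// TWOSHOT path additionally requires divisibility by world_size since the message is split per rank.
// Element size 2 stands in for fp16/bf16; the sizes below are illustrative only.
#include <cstddef>
#include <cstdio>

enum class Algo { ONESHOT, TWOSHOT };

static bool configuration_supported(Algo algo, size_t msg_elems, size_t world_size, size_t elem_bytes) {
  size_t elts_per_thread = 16 / elem_bytes;  // one 128-bit packed access per thread
  size_t msg_align = (algo == Algo::TWOSHOT) ? world_size * elts_per_thread : elts_per_thread;
  return msg_elems % msg_align == 0;
}

int main() {
  // fp16, 8 ranks: TWOSHOT needs a multiple of 8 * 8 = 64 elements, ONESHOT only a multiple of 8.
  std::printf("%d\n", configuration_supported(Algo::TWOSHOT, 4096, 8, 2));  // 1
  std::printf("%d\n", configuration_supported(Algo::TWOSHOT, 4104, 8, 2));  // 0 (4104 % 64 != 0)
  std::printf("%d\n", configuration_supported(Algo::ONESHOT, 4104, 8, 2));  // 1 (4104 % 8 == 0)
  return 0;
}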
param.ranks_per_node; + param.rank_offset = param.local_rank * param.elts_per_rank; + param.elts_per_block = + roundUp(divUp(param.elts_per_rank, static_cast(blocks_per_grid)), elts_per_thread); + break; } default: - ORT_THROW("Algorithm not supported here."); - } + ORT_THROW("Algorithm not supported here."); + } - return std::make_tuple(blocks_per_grid, threads_per_block); + return std::make_tuple(blocks_per_grid, threads_per_block); } template -void AllReduceDispatchMemcpy(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams ¶m, +void AllReduceDispatchMemcpy(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream) { - ORT_ENFORCE(!(USE_MEMCPY && PUSH_MODE), "Memcpy cannot be used with PUSH_MODE."); - size_t elts_per_thread = 16 / sizeof(T); - auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(algo, param, elts_per_thread); - - if (USE_MEMCPY) { - cudaMemcpyAsync(param.peer_comm_buffer_ptrs[param.local_rank], param.local_input_buffer_ptr, - param.elts_total * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } - - if (algo == AllReduceStrategyType::ONESHOT) { - oneShotAllReduceKernel - <<>>(param); - } else { - twoShotAllReduceKernel - <<>>(param); - } + ORT_ENFORCE(!(USE_MEMCPY && PUSH_MODE), "Memcpy cannot be used with PUSH_MODE."); + size_t elts_per_thread = 16 / sizeof(T); + auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(algo, param, elts_per_thread); + + if (USE_MEMCPY) { + cudaMemcpyAsync(param.peer_comm_buffer_ptrs[param.local_rank], param.local_input_buffer_ptr, + param.elts_total * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } + + if (algo == AllReduceStrategyType::ONESHOT) { + oneShotAllReduceKernel + <<>>(param); + } else { + twoShotAllReduceKernel + <<>>(param); + } } template -void AllReduceDispatchPushMode(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams ¶m, +void AllReduceDispatchPushMode(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream) { - if (static_cast>(config) & - static_cast>(AllReduceStrategyConfig::USE_MEMCPY)) { - AllReduceDispatchMemcpy(algo, config, param, stream); - } else { - AllReduceDispatchMemcpy(algo, config, param, stream); - } + if (static_cast>(config) & + static_cast>(AllReduceStrategyConfig::USE_MEMCPY)) { + AllReduceDispatchMemcpy(algo, config, param, stream); + } else { + AllReduceDispatchMemcpy(algo, config, param, stream); + } } -template //, bool USE_MEMCPY = false, bool PUSH_MODE = false> -void AllReduceDispatchRanksPerNode(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams ¶m, +template //, bool USE_MEMCPY = false, bool PUSH_MODE = false> +void AllReduceDispatchRanksPerNode(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream) { - if (static_cast>(config) & - static_cast>(AllReduceStrategyConfig::PUSH_MODE)) { - AllReduceDispatchPushMode(algo, config, param, stream); - } else { - AllReduceDispatchPushMode(algo, config, param, stream); - } + if (static_cast>(config) & + static_cast>(AllReduceStrategyConfig::PUSH_MODE)) { + AllReduceDispatchPushMode(algo, config, param, stream); + } else { + AllReduceDispatchPushMode(algo, config, param, stream); + } } template -void AllReduceDispatchType(AllReduceParams ¶m, AllReduceStrategyType strategy, AllReduceStrategyConfig config, +void AllReduceDispatchType(AllReduceParams& param, AllReduceStrategyType strategy, AllReduceStrategyConfig config, 
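// Worked example of the TWOSHOT sizing loop in kernelLaunchConfig above, restated on the host with
// concrete numbers. DEFAULT_BLOCK_SIZE = 512 is an assumed value for illustration; the element type
// is fp16 (8 elements per 128-bit access) and the MAX_ALL_REDUCE_BLOCKS adjustment is omitted.
#include <cstddef>
#include <cstdio>

static size_t div_up(size_t a, size_t b) { return (a + b - 1) / b; }
static size_t round_up(size_t a, size_t b) { return div_up(a, b) * b; }

int main() {
  constexpr size_t kWarpSize = 32;
  constexpr size_t kDefaultBlockSize = 512;  // assumption for this example
  const size_t elts_total = 1u << 20;        // 1M fp16 elements
  const size_t elts_per_thread = 16 / 2;     // = 8
  const size_t ranks_per_node = 8;

  size_t total_threads = round_up(elts_total / (elts_per_thread * ranks_per_node), kWarpSize);
  size_t blocks_per_grid = 1;
  while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > kDefaultBlockSize) {
    ++blocks_per_grid;  // smallest divisor of total_threads whose quotient fits in one block
  }
  const size_t threads_per_block = total_threads / blocks_per_grid;
  const size_t elts_per_rank = elts_total / ranks_per_node;
  const size_t elts_per_block = round_up(div_up(elts_per_rank, blocks_per_grid), elts_per_thread);

  // Prints: 16384 threads -> 32 blocks x 512 threads, 4096 elements per block
  std::printf("%zu threads -> %zu blocks x %zu threads, %zu elements per block\n",
              total_threads, blocks_per_grid, threads_per_block, elts_per_block);
  return 0;
}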
cudaStream_t stream) { - switch (param.ranks_per_node) { + switch (param.ranks_per_node) { case 2: - AllReduceDispatchRanksPerNode(strategy, config, param, stream); - break; + AllReduceDispatchRanksPerNode(strategy, config, param, stream); + break; case 4: - AllReduceDispatchRanksPerNode(strategy, config, param, stream); - break; + AllReduceDispatchRanksPerNode(strategy, config, param, stream); + break; case 6: - AllReduceDispatchRanksPerNode(strategy, config, param, stream); - break; + AllReduceDispatchRanksPerNode(strategy, config, param, stream); + break; case 8: - AllReduceDispatchRanksPerNode(strategy, config, param, stream); - break; + AllReduceDispatchRanksPerNode(strategy, config, param, stream); + break; default: - ORT_THROW("Custom all reduce only supported on {2, 4, 6, 8} GPUs per node."); - } + ORT_THROW("Custom all reduce only supported on {2, 4, 6, 8} GPUs per node."); + } } -AllReduceParams AllReduceParams::deserialize(const int32_t *buffer, size_t tp_size, size_t tp_rank, uint32_t flag) { - void *const *buffer_ptrs = reinterpret_cast(buffer); - AllReduceParams params; - - for (size_t i = 0; i < tp_size; ++i) { - params.peer_comm_buffer_ptrs[i] = buffer_ptrs[i]; - } - for (size_t i = 0; i < tp_size; ++i) { - params.peer_barrier_ptrs_in[i] = reinterpret_cast(buffer_ptrs[tp_size + i]); - } - for (size_t i = 0; i < tp_size; ++i) { - params.peer_barrier_ptrs_out[i] = reinterpret_cast(buffer_ptrs[2 * tp_size + i]); - } - params.barrier_flag = flag; - params.ranks_per_node = tp_size; - params.rank = tp_rank; - params.local_rank = tp_rank; - - return params; +AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tp_size, size_t tp_rank, uint32_t flag) { + void* const* buffer_ptrs = reinterpret_cast(buffer); + AllReduceParams params; + + for (size_t i = 0; i < tp_size; ++i) { + params.peer_comm_buffer_ptrs[i] = buffer_ptrs[i]; + } + for (size_t i = 0; i < tp_size; ++i) { + params.peer_barrier_ptrs_in[i] = reinterpret_cast(buffer_ptrs[tp_size + i]); + } + for (size_t i = 0; i < tp_size; ++i) { + params.peer_barrier_ptrs_out[i] = reinterpret_cast(buffer_ptrs[2 * tp_size + i]); + } + params.barrier_flag = flag; + params.ranks_per_node = tp_size; + params.rank = tp_rank; + params.local_rank = tp_rank; + + return params; } -void CustomAllReduce(AllReduceParams ¶ms, onnxruntime::MLDataType data_type, AllReduceStrategyType strategy, +void CustomAllReduce(AllReduceParams& params, onnxruntime::MLDataType data_type, AllReduceStrategyType strategy, AllReduceStrategyConfig config, cudaStream_t stream) { - ORT_ENFORCE(ConfigurationSupported(strategy, params.elts_total, params.ranks_per_node, data_type), - "Custom all-reduce configuration unsupported"); - if (data_type == onnxruntime::DataTypeImpl::GetType()) { - AllReduceDispatchType(params, strategy, config, stream); - } else if (data_type == onnxruntime::DataTypeImpl::GetType()) { - AllReduceDispatchType(params, strategy, config, stream); - } else { - ORT_THROW("Unsupported data type for CustomAllReduce"); - } + ORT_ENFORCE(ConfigurationSupported(strategy, params.elts_total, params.ranks_per_node, data_type), + "Custom all-reduce configuration unsupported"); + if (data_type == onnxruntime::DataTypeImpl::GetType()) { + AllReduceDispatchType(params, strategy, config, stream); + } else if (data_type == onnxruntime::DataTypeImpl::GetType()) { + AllReduceDispatchType(params, strategy, config, stream); + } else { + ORT_THROW("Unsupported data type for CustomAllReduce"); + } } size_t GetMaxRequiredWorkspaceSize(int world_size) { 
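// Host-side illustration of the workspace layout that AllReduceParams::deserialize above assumes:
// a flat array of 3 * tp_size pointers, in the order [peer comm buffers | barrier-in flags |
// barrier-out flags]. The vectors below are stand-ins for the real (IPC-shared) device allocations;
// the layout, not the storage, is the point of the sketch.
#include <cstdio>
#include <vector>

int main() {
  const size_t tp_size = 4;
  std::vector<int> comm(tp_size), barrier_in(tp_size), barrier_out(tp_size);

  std::vector<void*> serialized(3 * tp_size, nullptr);
  for (size_t r = 0; r < tp_size; ++r) {
    serialized[r] = &comm[r];                       // -> params.peer_comm_buffer_ptrs[r]
    serialized[tp_size + r] = &barrier_in[r];       // -> params.peer_barrier_ptrs_in[r]
    serialized[2 * tp_size + r] = &barrier_out[r];  // -> params.peer_barrier_ptrs_out[r]
  }

  std::printf("serialized %zu pointers for tp_size=%zu\n", serialized.size(), tp_size);
  return 0;
}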
- if (world_size <= 2) { - return 16 * 1000 * 1000; - } - return 8 * 1000 * 1000; + if (world_size <= 2) { + return 16 * 1000 * 1000; + } + return 8 * 1000 * 1000; } -Status SetPeerAccess(int rank, int world_size, bool enable, int &can_access_peer) { - const int src_node = rank; +Status SetPeerAccess(int rank, int world_size, bool enable, int& can_access_peer) { + const int src_node = rank; - for (int dst_node = 0; dst_node < world_size; dst_node++) { - if (dst_node == src_node) { - continue; - } + for (int dst_node = 0; dst_node < world_size; dst_node++) { + if (dst_node == src_node) { + continue; + } - CUDA_RETURN_IF_ERROR(cudaDeviceCanAccessPeer(&can_access_peer, src_node, dst_node)); + CUDA_RETURN_IF_ERROR(cudaDeviceCanAccessPeer(&can_access_peer, src_node, dst_node)); - if (!can_access_peer) { - return Status::OK(); - } + if (!can_access_peer) { + return Status::OK(); + } - if (enable) { - cudaDeviceEnablePeerAccess(dst_node, 0); - } else { - cudaDeviceDisablePeerAccess(dst_node); - } + if (enable) { + cudaDeviceEnablePeerAccess(dst_node, 0); + } else { + cudaDeviceDisablePeerAccess(dst_node); + } - auto const error = cudaGetLastError(); - if (error != cudaErrorPeerAccessAlreadyEnabled && error != cudaErrorPeerAccessNotEnabled) { - CUDA_RETURN_IF_ERROR(error); - } + auto const error = cudaGetLastError(); + if (error != cudaErrorPeerAccessAlreadyEnabled && error != cudaErrorPeerAccessNotEnabled) { + CUDA_RETURN_IF_ERROR(error); } + } - return Status::OK(); + return Status::OK(); } AllReduceStrategyType SelectImplementation(size_t message_size, int rank, int world_size, onnxruntime::MLDataType type) { - AllReduceStrategyType strategy = AllReduceStrategyType::NCCL; - if (type != onnxruntime::DataTypeImpl::GetType() && - type != onnxruntime::DataTypeImpl::GetType()) { - return strategy; - } + AllReduceStrategyType strategy = AllReduceStrategyType::NCCL; + if (type != onnxruntime::DataTypeImpl::GetType() && + type != onnxruntime::DataTypeImpl::GetType()) { + return strategy; + } - if (world_size != 2 && world_size != 4 && world_size != 6 && world_size != 8) { - return strategy; - } + if (world_size != 2 && world_size != 4 && world_size != 6 && world_size != 8) { + return strategy; + } - int can_access_peer = 0; - ORT_ENFORCE(SetPeerAccess(rank, world_size, true, can_access_peer) == Status::OK()); - // If P2P is not enabled, we cannot use the custom allreduce, so default to NCCL. - if (!can_access_peer) { - return strategy; - } + int can_access_peer = 0; + ORT_ENFORCE(SetPeerAccess(rank, world_size, true, can_access_peer) == Status::OK()); + // If P2P is not enabled, we cannot use the custom allreduce, so default to NCCL. 
+ if (!can_access_peer) { + return strategy; + } - const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size); - const size_t message_size_bytes = message_size * type->Size(); - - if (message_size_bytes <= maxWorkspaceSize) { - if (world_size <= 2) { - strategy = AllReduceStrategyType::ONESHOT; - } else if (world_size <= 4) { - if (message_size_bytes < 1 * 1000 * 1000) { - strategy = AllReduceStrategyType::ONESHOT; - } else { - strategy = AllReduceStrategyType::TWOSHOT; - } - } else { - if (message_size_bytes < 500 * 1000) { - strategy = AllReduceStrategyType::ONESHOT; - } else { - strategy = AllReduceStrategyType::TWOSHOT; - } - } - } + const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size); + const size_t message_size_bytes = message_size * type->Size(); - if (!ConfigurationSupported(strategy, message_size, world_size, type)) { - strategy = AllReduceStrategyType::NCCL; + if (message_size_bytes <= maxWorkspaceSize) { + if (world_size <= 2) { + strategy = AllReduceStrategyType::ONESHOT; + } else if (world_size <= 4) { + if (message_size_bytes < 1 * 1000 * 1000) { + strategy = AllReduceStrategyType::ONESHOT; + } else { + strategy = AllReduceStrategyType::TWOSHOT; + } + } else { + if (message_size_bytes < 500 * 1000) { + strategy = AllReduceStrategyType::ONESHOT; + } else { + strategy = AllReduceStrategyType::TWOSHOT; + } } + } - return strategy; + if (!ConfigurationSupported(strategy, message_size, world_size, type)) { + strategy = AllReduceStrategyType::NCCL; + } + + return strategy; } #endif -} // namespace collective -} // namespace cuda -} // namespace onnxruntime +} // namespace collective +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index d016d50d6c445..f3346d4513261 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -149,6 +149,7 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_MLFloat16, SimplifiedLayerNor class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_float, SimplifiedLayerNormalization); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, BFloat16_float_BFloat16, SimplifiedLayerNormalization); class CUDA_MS_OP_CLASS_NAME(1, Inverse); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulNBits); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulNBits); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulNBits); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4); @@ -363,6 +364,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 25c550874eb3b..58c94f966841b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -52,7 +52,7 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) #define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ GroupNormNHWCSumKernel \ <<>>( \ - params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \ + params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \ params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c, \ 
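// The strategy selection above reduces to a few size thresholds once the dtype, world-size and P2P
// checks pass. A compact host restatement for quick reference; the byte counts come straight from
// the code (16 MB / 8 MB workspace caps, 1 MB and 500 KB one-shot cut-offs), while the example
// message sizes are illustrative.
#include <cstddef>
#include <cstdio>

enum class Strategy { NCCL = 0, ONESHOT = 1, TWOSHOT = 2 };

static Strategy pick(size_t message_bytes, int world_size) {
  const size_t max_workspace = (world_size <= 2) ? 16u * 1000 * 1000 : 8u * 1000 * 1000;
  if (message_bytes > max_workspace) return Strategy::NCCL;  // falls back to NCCL
  if (world_size <= 2) return Strategy::ONESHOT;
  if (world_size <= 4) return message_bytes < 1u * 1000 * 1000 ? Strategy::ONESHOT : Strategy::TWOSHOT;
  return message_bytes < 500u * 1000 ? Strategy::ONESHOT : Strategy::TWOSHOT;
}

int main() {
  // 8 ranks: 256 KB -> ONESHOT, 2 MB -> TWOSHOT, 32 MB -> NCCL (exceeds the 8 MB cap).
  std::printf("%d %d %d\n", static_cast<int>(pick(256 * 1000, 8)),
              static_cast<int>(pick(2 * 1000 * 1000, 8)),
              static_cast<int>(pick(32u * 1000 * 1000, 8)));
  return 0;
}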
params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \ break; @@ -128,7 +128,6 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { - // tuning_ctx only used for ROCm EP. ORT_UNUSED_PARAMETER(tuning_ctx); diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu index de834db4b7440..026852623513b 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu @@ -19,9 +19,8 @@ namespace onnxruntime::llm { namespace kernels { namespace cutlass_kernels { -#ifdef ENABLE_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; -#endif +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; } // namespace cutlass_kernels } // namespace kernels } // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu index 97c71615ce54d..0849f6d9da042 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu @@ -19,9 +19,7 @@ namespace onnxruntime::llm { namespace kernels { namespace cutlass_kernels { -#ifdef ENABLE_BF16 template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; -#endif } // namespace cutlass_kernels } // namespace kernels } // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu index 55beb8b9ca029..3f804b52034e1 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu @@ -146,12 +146,40 @@ template void launch_scaled_zero_point_kernel( half* scaled_zero_point, int n, int k_blocks, float default_zero_point); +template void launch_transpose_scale_kernel<__nv_bfloat16>( + cudaStream_t stream, + const __nv_bfloat16* scale, + __nv_bfloat16* transposed_scale, + int n, int k_blocks); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const __nv_bfloat16* zero_point, + const __nv_bfloat16* transposed_scale, + __nv_bfloat16* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const __nv_bfloat16* transposed_scale, + __nv_bfloat16* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +// zero point is 4 bits packed. 
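// One of the instantiations here takes zero points that arrive two-per-byte ("4 bits packed").
// A small host sketch of the nibble convention used by the uint4 handling in this file (see the
// unpack/transpose kernel below): the even index sits in the low nibble, the odd index in the high
// nibble. pack_two_uint4 is a hypothetical helper, named only for illustration.
#include <cstdint>
#include <cstdio>

static uint8_t pack_two_uint4(uint8_t even_val, uint8_t odd_val) {
  return static_cast<uint8_t>((even_val & 0x0f) | ((odd_val & 0x0f) << 4));
}

int main() {
  const uint8_t byte = pack_two_uint4(/*even*/ 3, /*odd*/ 12);
  const uint8_t even = byte & 0x0f;  // low nibble  -> even column/index
  const uint8_t odd = byte >> 4;     // high nibble -> odd column/index
  std::printf("packed=0x%02x even=%u odd=%u\n", byte, even, odd);  // packed=0xc3 even=3 odd=12
  return 0;
}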
+template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const __nv_bfloat16* transposed_scale, + __nv_bfloat16* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + // CUDA kernel to unpack uint4, transpose, and pack into int8 directly __global__ void unpack_transpose_pack_uint4_to_int8_kernel_v2( const unsigned char* __restrict__ packed_weight, signed char* __restrict__ packed_transposed_weight, - int n, // original matrix rows - int k) // original matrix columns + int n, // original matrix rows + int k) // original matrix columns { // The output 'packed_transposed_weight' has dimensions k x (n/2) bytes. // Each thread processes one byte in the output. @@ -184,9 +212,9 @@ __global__ void unpack_transpose_pack_uint4_to_int8_kernel_v2( unsigned char packed_data_0 = packed_weight[packed_weight_idx_0]; signed char val_0; - if ((c_orig_0 % 2) == 0) { // If original column is even, it's the lower 4 bits + if ((c_orig_0 % 2) == 0) { // If original column is even, it's the lower 4 bits val_0 = (signed char)(packed_data_0 & 0x0f) - default_zero_point; - } else { // If original column is odd, it's the upper 4 bits + } else { // If original column is odd, it's the upper 4 bits val_0 = (signed char)(packed_data_0 >> 4) - default_zero_point; } @@ -201,9 +229,9 @@ __global__ void unpack_transpose_pack_uint4_to_int8_kernel_v2( unsigned char packed_data_1 = packed_weight[packed_weight_idx_1]; signed char val_1; - if ((c_orig_1 % 2) == 0) { // If original column is even, it's the lower 4 bits + if ((c_orig_1 % 2) == 0) { // If original column is even, it's the lower 4 bits val_1 = (signed char)(packed_data_1 & 0x0f) - default_zero_point; - } else { // If original column is odd, it's the upper 4 bits + } else { // If original column is odd, it's the upper 4 bits val_1 = (signed char)(packed_data_1 >> 4) - default_zero_point; } @@ -230,7 +258,6 @@ __global__ void transpose_uint8_matrix_and_convert_to_int8_kernel( const uint8_t* __restrict__ input, // shape: (n, k) int8_t* __restrict__ output, // shape: (k, n) int n, int k) { - int row = blockIdx.y * blockDim.y + threadIdx.y; // index in n int col = blockIdx.x * blockDim.x + threadIdx.x; // index in k @@ -246,7 +273,6 @@ void transpose_uint8_matrix_and_convert_to_int8( int8_t* output, // shape: (k, n) const uint8_t* input, // shape: (n, k) int n, int k) { - dim3 blockDim(16, 16); dim3 gridDim((k + blockDim.x - 1) / blockDim.x, (n + blockDim.y - 1) / blockDim.y); @@ -254,7 +280,6 @@ void transpose_uint8_matrix_and_convert_to_int8( transpose_uint8_matrix_and_convert_to_int8_kernel<<>>(input, output, n, k); } - } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu index b0a72a1d2506a..ba7ad755e369c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu @@ -28,4 +28,3 @@ namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index a60fa01c26a0e..2611dde238f48 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -56,223 +56,223 @@ static constexpr int WARP_SIZE = 32; // 
in the softmax kernel when we extend this module to support expert-choice routing. template __launch_bounds__(TPB) __global__ - void moe_softmax(const T *input, const bool *finished, T *output, const int num_cols) { - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; + void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; - __shared__ float normalizing_factor; - __shared__ float float_max; + __shared__ float normalizing_factor; + __shared__ float float_max; - const int thread_row_offset = blockIdx.x * num_cols; + const int thread_row_offset = blockIdx.x * num_cols; #if CUDA_VERSION >= 12090 - ::cuda::std::plus sum; + ::cuda::std::plus sum; #else - // Deprecated on CUDA 12.9 - cub::Sum sum; + // Deprecated on CUDA 12.9 + cub::Sum sum; #endif - float threadData(-FLT_MAX); + float threadData(-FLT_MAX); - // Don't touch finished rows. - if ((finished != nullptr) && finished[blockIdx.x]) { - return; - } + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; - threadData = max(static_cast(input[idx]), threadData); - } + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); - if (threadIdx.x == 0) { - float_max = maxElem; - } - __syncthreads(); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); - threadData = 0; + threadData = 0; - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; - threadData += exp((static_cast(input[idx]) - float_max)); - } + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); - if (threadIdx.x == 0) { - normalizing_factor = 1.f / Z; - } - __syncthreads(); + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; - output[idx] = T(val); - } + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + } } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template -__launch_bounds__(TPB) __global__ void moe_top_k(const T *, const bool *, T *, int *, int *, int, int, bool) { - // Does not support pre-Kepler architectures - ; +__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, int, bool) { + // Does not support pre-Kepler architectures + ; } #else template __launch_bounds__(TPB) __global__ - void moe_top_k(const T *inputs_after_softmax, const bool *finished, T *output, int *indices, int *source_rows, + void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, int* indices, int* 
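// moe_softmax above is the standard numerically stable softmax done block-wide with cub::BlockReduce:
// (1) max-reduce the row, (2) sum exp(x - max), (3) scale by the reciprocal. A plain host restatement
// of the per-row math, for reference only; it is not the kernel itself.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> row_softmax(const std::vector<float>& row) {
  const float row_max = *std::max_element(row.begin(), row.end());  // phase 1: max reduce
  float z = 0.f;
  for (float v : row) z += std::exp(v - row_max);                   // phase 2: sum of exponentials
  std::vector<float> out(row.size());
  for (size_t i = 0; i < row.size(); ++i) {
    out[i] = std::exp(row[i] - row_max) / z;                        // phase 3: normalize
  }
  return out;
}

int main() {
  const auto p = row_softmax({1.f, 2.f, 3.f, 4.f});
  std::printf("%.3f %.3f %.3f %.3f\n", p[0], p[1], p[2], p[3]);  // ~0.032 0.087 0.237 0.644
  return 0;
}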
source_rows, int num_experts, int k, bool normalize_routing_weights) { - using cub_kvp = cub::KeyValuePair; - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; - - cub_kvp thread_kvp; - cub::ArgMax arg_max; - - int num_rows = gridDim.x; - const int block_row = blockIdx.x; - - const bool should_process_row = finished ? !finished[block_row] : true; - const int thread_row_offset = blockIdx.x * num_experts; - float output_row_sum = 0.f; - for (int k_idx = 0; k_idx < k; ++k_idx) { - thread_kvp.key = 0; - thread_kvp.value = T(-1.f); - - cub_kvp inp_kvp; - for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { - const int idx = thread_row_offset + expert; - inp_kvp.key = expert; - inp_kvp.value = inputs_after_softmax[idx]; - - for (int prior_k = 0; prior_k < k_idx; ++prior_k) { - const int prior_winning_expert = indices[k * block_row + prior_k]; - - if (prior_winning_expert == expert) { - inp_kvp = thread_kvp; - } - } - - thread_kvp = arg_max(inp_kvp, thread_kvp); + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool should_process_row = finished ? !finished[block_row] : true; + const int thread_row_offset = blockIdx.x * num_experts; + float output_row_sum = 0.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_row_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; } + } - const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); - if (threadIdx.x == 0) { - const int idx = k * block_row + k_idx; - output[idx] = result_kvp.value; - indices[idx] = should_process_row ? result_kvp.key : num_experts; - source_rows[idx] = k_idx * num_rows + block_row; + thread_kvp = arg_max(inp_kvp, thread_kvp); + } - if (normalize_routing_weights && k_idx == k - 1) { + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? 
result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + + if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll - for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); - } - } + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); } - __syncthreads(); + } } + __syncthreads(); + } } #endif #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template -__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T *, T *, int *, int *, const float) { - // Does not support pre-Kepler architectures - ; +__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T*, T*, int*, int*, const float) { + // Does not support pre-Kepler architectures + ; } #else template __launch_bounds__(TPB) __global__ - void sparse_mixer_top2(const T *inputs, T *output, int *indices, int *source_rows, const float jitter_eps) { - static constexpr int K = 2; + void sparse_mixer_top2(const T* inputs, T* output, int* indices, int* source_rows, const float jitter_eps) { + static constexpr int K = 2; - using cub_kvp = cub::KeyValuePair; - using KVBlockReduce = cub::BlockReduce; + using cub_kvp = cub::KeyValuePair; + using KVBlockReduce = cub::BlockReduce; - __shared__ float result_kvp_value[K]; - __shared__ typename KVBlockReduce::TempStorage kvTmpStorage; + __shared__ float result_kvp_value[K]; + __shared__ typename KVBlockReduce::TempStorage kvTmpStorage; - cub_kvp thread_kvp; - cub::ArgMax arg_max; + cub_kvp thread_kvp; + cub::ArgMax arg_max; - int num_rows = gridDim.x; - const int block_row = blockIdx.x; + int num_rows = gridDim.x; + const int block_row = blockIdx.x; - const int thread_row_offset = blockIdx.x * NUM_EXPERTS; + const int thread_row_offset = blockIdx.x * NUM_EXPERTS; - float factor[K]; - bool logits_mask[K]; + float factor[K]; + bool logits_mask[K]; #pragma unroll - for (int k_idx = 0; k_idx < K; ++k_idx) { - thread_kvp.key = 0; - thread_kvp.value = T(-1.f); + for (int k_idx = 0; k_idx < K; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); - cub_kvp inp_kvp; + cub_kvp inp_kvp; #pragma unroll - for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { - const int idx = thread_row_offset + expert; - inp_kvp.key = expert; - inp_kvp.value = inputs[idx]; + for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { + const int idx = thread_row_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs[idx]; - for (int prior_k = 0; prior_k < k_idx; ++prior_k) { - const int prior_winning_expert = indices[K * block_row + prior_k]; + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[K * block_row + prior_k]; - if (prior_winning_expert == expert) { - inp_kvp = thread_kvp; - } - } - - thread_kvp = arg_max(inp_kvp, thread_kvp); + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; } + } - const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max); - if (threadIdx.x == 0) { - const int idx = K * block_row + k_idx; - result_kvp_value[k_idx] = (float)result_kvp.value; - indices[idx] = result_kvp.key; - source_rows[idx] = k_idx * num_rows + block_row; - } - __syncthreads(); + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = K * block_row + k_idx; + result_kvp_value[k_idx] = (float)result_kvp.value; + indices[idx] = result_kvp.key; 
+ source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); #pragma unroll - for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { - const int idx = thread_row_offset + expert; - factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value[k_idx]); - logits_mask[k_idx] = (result_kvp_value[k_idx] - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]); - if (k_idx == 1 && expert == indices[K * block_row]) { - logits_mask[1] = true; - } - } + for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { + const int idx = thread_row_offset + expert; + factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value[k_idx]); + logits_mask[k_idx] = (result_kvp_value[k_idx] - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]); + if (k_idx == 1 && expert == indices[K * block_row]) { + logits_mask[1] = true; + } } + } #pragma unroll - for (int k_idx = 0; k_idx < K; ++k_idx) { - float row_sum(0); + for (int k_idx = 0; k_idx < K; ++k_idx) { + float row_sum(0); #pragma unroll - for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) { - const int idx = thread_row_offset + ii; - row_sum += logits_mask[k_idx] ? 0 : exp((static_cast(inputs[idx]) - result_kvp_value[k_idx])); - } + for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) { + const int idx = thread_row_offset + ii; + row_sum += logits_mask[k_idx] ? 0 : exp((static_cast(inputs[idx]) - result_kvp_value[k_idx])); + } #pragma unroll - for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS); - } + for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS); + } - const float normalizing_factor = 1.f / row_sum; + const float normalizing_factor = 1.f / row_sum; - const int idx = K * block_row + k_idx; - if (threadIdx.x == indices[idx]) { - const int input_idx = thread_row_offset + threadIdx.x; - output[idx] = logits_mask[k_idx] ? 0 - : exp((static_cast(inputs[input_idx]) - result_kvp_value[k_idx])) * - normalizing_factor; - } + const int idx = K * block_row + k_idx; + if (threadIdx.x == indices[idx]) { + const int input_idx = thread_row_offset + threadIdx.x; + output[idx] = logits_mask[k_idx] ? 0 + : exp((static_cast(inputs[input_idx]) - result_kvp_value[k_idx])) * + normalizing_factor; } + } } #endif @@ -291,300 +291,301 @@ __launch_bounds__(TPB) __global__ */ template -__launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__ - void topk_gating_softmax(const T *input, const bool *finished, T *output, int num_rows, int *indices, - int *source_rows, int k, bool normalize_routing_weights) { - // We begin by enforcing compile time assertions and setting up compile time constants. - static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); - static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); - static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); - static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); - - // Number of bytes each thread pulls in per load - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); - static constexpr int ELTS_PER_ROW = NUM_EXPERTS; - static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; - static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; - - // Restrictions based on previous section. 
- static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); - static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); - static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); - static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); - - // We have NUM_EXPERTS elements per row. We specialize for small #experts - static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; - static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; - static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; - - // Restrictions for previous section. - static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); - - // ===================== From this point, we finally start computing run-time variables. ======================== - - // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. - // This, each block processes a chunk of rows. We start by computing the start row for each block. - const int cta_base_row = blockIdx.x * ROWS_PER_CTA; - - // Now, using the base row per thread block, we compute the base row per warp. - const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; - - // The threads in a warp are split into sub-groups that will work on a row. - // We compute row offset for each thread sub-group - const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; - const int thread_row = warp_base_row + thread_row_in_warp; - - // Threads with indices out of bounds should early exit here. - if (thread_row >= num_rows) - return; - const bool should_process_row = finished ? !finished[thread_row] : true; - - // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the - // row it will read. - const T *thread_row_ptr = input + thread_row * ELTS_PER_ROW; - - // Now, we compute the group each thread belong to in order to determine the first column to start loads. - const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; - const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const T *thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; - - // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, - // this can support all powers of 2 up to 16. - using AccessType = cutlass::AlignedArray; - - // Finally, we pull in the data from global mem - cutlass::Array row_chunk_input; - AccessType *row_chunk_vec_ptr = reinterpret_cast(&row_chunk_input); - const AccessType *vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices, + int* source_rows, int k, bool normalize_routing_weights) { + // We begin by enforcing compile time assertions and setting up compile time constants. 
+ static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + return; + const bool should_process_row = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. 
+ using AccessType = cutlass::AlignedArray; + + // Finally, we pull in the data from global mem + cutlass::Array row_chunk_input; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_input); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); #pragma unroll - for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { - row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; - } - - using ComputeType = float; - using Converter = cutlass::NumericArrayConverter; - Converter compute_type_converter; - cutlass::Array row_chunk = compute_type_converter(row_chunk_input); - - // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just - // convert to float afterwards for the exp + sum reduction. - ComputeType thread_max = row_chunk[0]; + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + using ComputeType = float; + using Converter = cutlass::NumericArrayConverter; + Converter compute_type_converter; + cutlass::Array row_chunk = compute_type_converter(row_chunk_input); + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + ComputeType thread_max = row_chunk[0]; #pragma unroll - for (int ii = 1; ii < VPT; ++ii) { - thread_max = max(thread_max, row_chunk[ii]); - } + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); - } + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } - // From this point, thread max in all the threads have the max within the row. - // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. - float row_sum = 0; + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; #pragma unroll - for (int ii = 0; ii < VPT; ++ii) { - row_chunk[ii] = expf(row_chunk[ii] - thread_max); - row_sum += row_chunk[ii]; - } + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); - } + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } - // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables - // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to - // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. 
- // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the - // argmax after computing the softmax. - const float reciprocal_row_sum = 1.f / row_sum; + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; #pragma unroll - for (int ii = 0; ii < VPT; ++ii) { - row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; - } - - // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along - // with the max index.​ - int start_col = first_elt_read_by_thread; - static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; - - float output_row_sum = 0.f; - for (int k_idx = 0; k_idx < k; ++k_idx) { - // First, each thread does the local argmax - float max_val = row_chunk[0]; - int expert = start_col; + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index.​ + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + float output_row_sum = 0.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; #pragma unroll - for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { #pragma unroll - for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { - float val = row_chunk[ldg * ELTS_PER_LDG + ii]; - - // No check on the experts here since columns with the smallest index are processed first and only - // updated if > (not >=) - if (val > max_val) { - max_val = val; - expert = col + ii; - } - } + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; } + } + } // Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. // This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can // then blank out their max with -inf and the warp can run more iterations... 
#pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); - - // We want lower indices to "win" in every thread so we break ties this way - if (other_max > max_val || (other_max == max_val && other_expert < expert)) { - max_val = other_max; - expert = other_expert; - } - } + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } - // Write the max for this k iteration to global memory. - if (thread_group_idx == 0) { - // The lead thread from each sub-group will write out the final results to global memory. (This will be a - // single) thread per row of the input/output matrices. - const int idx = k * thread_row + k_idx; - output[idx] = T(max_val); - output_row_sum = output_row_sum + static_cast(max_val); - indices[idx] = should_process_row ? expert : NUM_EXPERTS; - source_rows[idx] = k_idx * num_rows + thread_row; - - if (normalize_routing_weights && k_idx == k - 1) { + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = T(max_val); + output_row_sum = output_row_sum + static_cast(max_val); + indices[idx] = should_process_row ? expert : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + + if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll - for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); - } - } + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); } + } + } - // Finally, we clear the value in the thread with the current max if there is another iteration to run. - if (k_idx + 1 < k) { - const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; - const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; - - // Only the thread in the group which produced the max will reset the "winning" value to -inf. - if (thread_group_idx == thread_to_clear_in_group) { - const int offset_for_expert = expert % ELTS_PER_LDG; - // Safe to set to any negative value since row_chunk values must be between 0 and 1. - row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f); - } - } + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. 
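To spell out the k-iteration scheme above (argmax, write out the winner, then mask it with a negative sentinel so the next pass finds the runner-up), a minimal host-side reference might look as follows; topk_by_masking and its arguments are hypothetical names used only for illustration.

#include <utility>
#include <vector>

// Reference top-k over a softmax row: any negative sentinel is a valid mask
// because softmax outputs lie in [0, 1].
std::vector<std::pair<float, int>> topk_by_masking(std::vector<float> row, int k) {
  std::vector<std::pair<float, int>> result;
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    int best = 0;
    for (int i = 1; i < static_cast<int>(row.size()); ++i) {
      if (row[i] > row[best]) best = i;  // strict '>' keeps the lowest index on ties
    }
    result.emplace_back(row[best], best);
    row[best] = -1.0f;  // blank the winner before the next iteration
  }
  return result;
}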
+ row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f); + } } + } } namespace detail { // Constructs some constants needed to partition the work across threads at compile time. -template struct TopkConstants { - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); - static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, (int)EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); - static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; - static constexpr int THREADS_PER_ROW = EXPERTS / VPT; - static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = std::max(1, (int)EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; }; -} // namespace detail +} // namespace detail template -void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T *output, int *indices, int *source_row, +void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, int num_rows, int /*num_experts*/, int k, bool normalize_routing_weights, cudaStream_t stream) { - static constexpr unsigned long MAX_BYTES_PER_LDG = 16; - - static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); - using Constants = detail::TopkConstants; - static constexpr int VPT = Constants::VPT; - static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; - const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; - - dim3 block_dim(WARP_SIZE, WARPS_PER_TB); - topk_gating_softmax<<>>( - input, finished, output, num_rows, indices, source_row, k, normalize_routing_weights); + static constexpr unsigned long MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topk_gating_softmax<<>>( + input, finished, output, num_rows, indices, source_row, k, normalize_routing_weights); } template -void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output, - int *indices, int *source_row, int num_rows, int num_experts, int k, +void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, + int* indices, int* source_row, int num_rows, int num_experts, int k, bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream) { - static constexpr int WARPS_PER_TB = 4; + static constexpr int WARPS_PER_TB = 4; - if (use_sparse_mixer) { - static constexpr int TPB = WARP_SIZE * WARPS_PER_TB; - static constexpr float jitter_eps = 0.01f; + if (use_sparse_mixer) { + static constexpr int TPB = 
WARP_SIZE * WARPS_PER_TB; + static constexpr float jitter_eps = 0.01f; - switch (num_experts) { - case 8: { - sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); - break; - } - case 16: { - sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); - break; - } + switch (num_experts) { + case 8: { + sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); + break; + } + case 16: { + sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); + break; + } - default: { - ORT_THROW("Sparse mixer only supports 8 and 16 experts"); - } - } - return; + default: { + ORT_THROW("Sparse mixer only supports 8 and 16 experts"); + } } + return; + } - switch (num_experts) { + switch (num_experts) { case 2: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 4: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 8: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 16: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 32: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 64: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 128: { - topk_gating_softmax_launcher_helper( - input, finished, output, indices, source_row, num_rows, num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper( + input, finished, output, indices, source_row, num_rows, num_experts, k, normalize_routing_weights, stream); + break; } case 256: { - topk_gating_softmax_launcher_helper( - input, finished, output, indices, source_row, num_rows, num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper( + input, finished, output, indices, source_row, num_rows, num_experts, k, normalize_routing_weights, stream); + break; } default: { - static constexpr int TPB = 256; - moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); - moe_top_k<<>>(softmax_temp_output, finished, output, indices, 
source_row, - num_experts, k, normalize_routing_weights); - } + static constexpr int TPB = 256; + moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); + moe_top_k<<>>(softmax_temp_output, finished, output, indices, source_row, + num_experts, k, normalize_routing_weights); } + } } // ========================== CUB Sorting things ==================================== @@ -594,188 +595,193 @@ CubKeyValueSorter::CubKeyValueSorter(int num_experts) : num_experts_(num_experts), num_bits_((int)log2(num_experts) + 1) {} void CubKeyValueSorter::update_num_experts(int num_experts) { - num_experts_ = num_experts; - num_bits_ = (int)log2(num_experts) + 1; + num_experts_ = num_experts; + num_bits_ = (int)log2(num_experts) + 1; } size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs) { - num_key_value_pairs_ = num_key_value_pairs; - size_t required_storage = 0; - int *null_int = nullptr; - cub::DeviceRadixSort::SortPairs(NULL, required_storage, null_int, null_int, null_int, null_int, - (int)num_key_value_pairs, 0, num_bits_); - return required_storage; + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs(NULL, required_storage, null_int, null_int, null_int, null_int, + (int)num_key_value_pairs, 0, num_bits_); + return required_storage; } -void CubKeyValueSorter::run(void *workspace, const size_t workspace_size, const int *keys_in, int *keys_out, - const int *values_in, int *values_out, const size_t num_key_value_pairs, +void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, + const int* values_in, int* values_out, const size_t num_key_value_pairs, cudaStream_t stream) { - size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); - size_t actual_ws_size = workspace_size; - - if (expected_ws_size > workspace_size) { - ORT_THROW( - "Error. The allocated workspace is too small to run this problem. Expected workspace size of at least ", - expected_ws_size, " but got problem size ", workspace_size, "\n"); - } - cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, - (int)num_key_value_pairs, 0, num_bits_, stream); + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + ORT_THROW( + "Error. The allocated workspace is too small to run this problem. 
Expected workspace size of at least ", + expected_ws_size, " but got problem size ", workspace_size, "\n"); + } + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, + (int)num_key_value_pairs, 0, num_bits_, stream); } // ============================== Infer GEMM sizes ================================= -__device__ inline int find_total_elts_leq_target(const int *sorted_indices, const int arr_length, const int target) { - int64_t low = 0, high = arr_length - 1, target_location = -1; - while (low <= high) { - int64_t mid = (low + high) / 2; - - if (sorted_indices[mid] > target) { - high = mid - 1; - } else { - low = mid + 1; - target_location = mid; - } +__device__ inline int find_total_elts_leq_target(const int* sorted_indices, const int arr_length, const int target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; } - return target_location + 1; + } + return target_location + 1; } // Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. // Assumes we want to perform output = matmul(inputs, experts) + bias -__global__ void compute_total_rows_before_expert_kernel(const int *sorted_experts, const int sorted_experts_len, - const int64_t num_experts, int64_t *total_rows_before_expert) { - // First, compute the global tid. We only need 1 thread per expert. - const int expert = blockIdx.x * blockDim.x + threadIdx.x; - if (expert >= num_experts) - return; - - // This should construct the last index where each expert occurs. - total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); +__global__ void compute_total_rows_before_expert_kernel(const int* sorted_experts, const int sorted_experts_len, + const int64_t num_experts, int64_t* total_rows_before_expert) { + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) + return; + + // This should construct the last index where each expert occurs. 
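The binary search above returns, for a given expert id, the number of entries in the sorted expert array that are less than or equal to it, i.e. the end offset of that expert's contiguous block of rows. A host-side equivalent using the standard library, shown only as an assumed-equivalent sketch with an illustrative name, would be:

#include <algorithm>
#include <cstdint>
#include <vector>

// rows_before[e] = count of sorted expert ids <= e, i.e. the end offset of
// expert e's row range after the radix sort.
std::vector<int64_t> rows_before_expert(const std::vector<int>& sorted_experts, int num_experts) {
  std::vector<int64_t> rows_before(num_experts);
  for (int e = 0; e < num_experts; ++e) {
    rows_before[e] = std::upper_bound(sorted_experts.begin(), sorted_experts.end(), e) -
                     sorted_experts.begin();
  }
  return rows_before;
}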
+ total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); } -__global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, int num_experts, int local_num_experts, +__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts, int local_num_experts, int local_experts_start_index) { - const int expert = blockIdx.x * blockDim.x + threadIdx.x; - const int local_experts_end_index = local_experts_start_index + local_num_experts - 1; + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + const int local_experts_end_index = local_experts_start_index + local_num_experts - 1; - int total_past_rows = 0; - if (local_experts_start_index > 0) { - total_past_rows = total_rows_before_expert[local_experts_start_index - 1]; - } + int total_past_rows = 0; + if (local_experts_start_index > 0) { + total_past_rows = total_rows_before_expert[local_experts_start_index - 1]; + } - if (expert < local_experts_start_index || expert > local_experts_end_index) { - return; - } + if (expert < local_experts_start_index || expert > local_experts_end_index) { + return; + } - total_rows_before_expert[expert] -= total_past_rows; + total_rows_before_expert[expert] -= total_past_rows; } template CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer) - : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), - normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { - moe_gemm_runner_.initialize(sm_version); + : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { + moe_gemm_runner_.initialize(sm_version); } template size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_rows, const size_t hidden_size, const size_t inter_size, size_t num_experts, size_t k) { - total_covered_rows_ = k * num_rows; - - const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); - const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); - const size_t padded_experts = pad_to_multiple_of_16(num_experts); - const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); - size_t num_softmax_outs = 0; - - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || num_experts > 256) { - num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); - } - - // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them - // in Encoder or Decoder before invoking FfnLayer forward. - size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ - total_ws_bytes += buf_size * sizeof(T); // permuted_data - total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ - total_ws_bytes += num_softmax_outs * sizeof(T); - const size_t bytes_for_fc1_result = has_fc3_ ? 
2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); - const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); - sorter_.update_num_experts(static_cast(num_experts)); - - size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; - if (sorter_ws_size_bytes > bytes_for_fc1_result) { - size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result); - bytes_for_intermediate_and_sorting += remaining_bytes; - } - - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace - return total_ws_bytes; + total_covered_rows_ = k * num_rows; + + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); + size_t num_softmax_outs = 0; + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); + } + + // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them + // in Encoder or Decoder before invoking FfnLayer forward. + size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(T); // permuted_data + total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + total_ws_bytes += num_softmax_outs * sizeof(T); + const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + sorter_.update_num_experts(static_cast(num_experts)); + + size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + return total_ws_bytes; } template -void CutlassMoeFCRunner::configure_ws_ptrs(char *ws_ptr, size_t num_rows, +void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, size_t num_rows, const size_t hidden_size, const size_t inter_size, size_t num_experts, size_t k) { - const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); - const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); - const size_t padded_experts = pad_to_multiple_of_16(num_experts); - const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); - - source_rows_ = reinterpret_cast(ws_ptr); - permuted_rows_ = source_rows_ + num_moe_inputs; - permuted_experts_ = permuted_rows_ + num_moe_inputs; - permuted_data_ = reinterpret_cast(permuted_experts_ + num_moe_inputs); - - total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); - - if (has_fc3_) { - fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); - } else { - fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - } - - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || 
num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); - } else { - softmax_out_ = nullptr; - } + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); + + source_rows_ = reinterpret_cast(ws_ptr); + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + permuted_data_ = reinterpret_cast(permuted_experts_ + num_moe_inputs); + + total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); + + if (has_fc3_) { + fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + } else { + fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + } + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } } namespace { struct __align__(8) Half4 { - half2 x; - half2 y; + half2 x; + half2 y; }; // TODO(wy): move to common header -template struct T4; -template <> struct T4 { - using Type = float4; +template +struct T4; +template <> +struct T4 { + using Type = float4; }; -template <> struct T4 { - using Type = Half4; +template <> +struct T4 { + using Type = Half4; }; -template struct T2; -template <> struct T2 { - using Type = float2; +template +struct T2; +template <> +struct T2 { + using Type = float2; }; -template <> struct T2 { - using Type = half2; +template <> +struct T2 { + using Type = half2; }; inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } inline __device__ float4 operator*(const float4 a, const float4 b) { - return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); + return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } // TODO(wy): use cuda common header and investigate pipeline build issue. -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) inline __device__ half operator*(const half a, const half b) { return __float2half(__half2float(a) * __half2float(b)); } @@ -784,209 +790,210 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha // TODO(wy): use cuda common header and investigate pipeline build issue. 
inline __device__ Half4 operator*(const Half4 a, const Half4 b) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) - Half4 result; - result.x = a.x * b.x; - result.y = a.y * b.y; - return result; + Half4 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; #else - return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; + return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; #endif } -} // anonymous namespace +} // anonymous namespace -template __global__ void elementWiseMulKernel(T *output, T const *input, size_t inter_size) { - int const tid = threadIdx.x; - int const token = blockIdx.x; - - output = output + token * inter_size; - input = input + token * inter_size; - for (int i = tid; i < inter_size; i += blockDim.x) { - T fc1_value = input[i]; - output[i] = fc1_value * output[i]; - } +template +__global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_size) { + int const tid = threadIdx.x; + int const token = blockIdx.x; + + output = output + token * inter_size; + input = input + token * inter_size; + for (int i = tid; i < inter_size; i += blockDim.x) { + T fc1_value = input[i]; + output[i] = fc1_value * output[i]; + } } template -void elementWiseMul(T *output, T const *input, int inter_size, int num_tokens, cudaStream_t stream) { - int const blocks = num_tokens; - - if (inter_size & 3 == 0) { - using vec_type = typename T4::Type; - int const threads = std::min(inter_size / 4, 1024); - elementWiseMulKernel<<>>( - reinterpret_cast(output), reinterpret_cast(input), inter_size / 4); - } else if (inter_size & 1 == 0) { - using vec_type = typename T2::Type; - int const threads = std::min(inter_size / 2, 1024); - elementWiseMulKernel<<>>( - reinterpret_cast(output), reinterpret_cast(input), inter_size / 2); - } else { - int const threads = std::min(inter_size, 1024); - elementWiseMulKernel<<>>(output, input, inter_size); - } +void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, cudaStream_t stream) { + int const blocks = num_tokens; + + if (inter_size & 3 == 0) { + using vec_type = typename T4::Type; + int const threads = std::min(inter_size / 4, 1024); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 4); + } else if (inter_size & 1 == 0) { + using vec_type = typename T2::Type; + int const threads = std::min(inter_size / 2, 1024); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 2); + } else { + int const threads = std::min(inter_size, 1024); + elementWiseMulKernel<<>>(output, input, inter_size); + } } template void CutlassMoeFCRunner::run_moe_fc( - const T *input_activations, const T *gating_output, const WeightType *fc1_expert_weights, const T *fc1_scales, - const T *fc1_expert_biases, ActivationType fc1_activation_type, const WeightType *fc3_expert_weights, - const T *fc3_scales, const T *fc3_expert_biases, const WeightType *fc2_expert_weights, const T *fc2_scales, + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, + const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, - int 
local_experts_start_index, int k, char *workspace_ptr, T *fc2_result, const bool *finished, int active_rows, - T *expert_scales, int *expanded_source_row_to_expanded_dest_row, int *expert_for_source_row, cudaStream_t stream) { - static constexpr bool scales_required = - std::is_same::value || std::is_same::value; - + int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, + T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { + static constexpr bool scales_required = + std::is_same::value || std::is_same::value; + + if (scales_required) { + if (fc1_scales == nullptr) { + ORT_THROW("[Run MoE FC] Scales expected but scale for first matmul is a null pointer"); + } else if (fc2_scales == nullptr) { + ORT_THROW("[Run MoE FC] Scales expected but scale for second matmul is a null pointer"); + } + } else { + if (fc1_scales != nullptr) { + ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); + } else if (fc2_scales != nullptr) { + ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); + } + } + + configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k)); + topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, + source_rows_, num_rows, num_experts, k, normalize_routing_weights_, + use_sparse_mixer_, stream); + + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); + sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, + source_rows_, permuted_rows_, k * num_rows, stream); + + initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, + expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, + k, stream); + + const int expanded_active_expert_rows = k * active_rows; + compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, + total_rows_before_expert_, stream); + + if (local_num_experts < num_experts) { + dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, + stream); + } + + // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, + // expanded_active_expert_rows); + moe_gemm_runner_.moe_gemm_bias_act( + permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, + fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, inter_size, hidden_size, local_num_experts, fc1_activation_type, stream); + + if (has_fc3_) { if (scales_required) { - if (fc1_scales == nullptr) { - ORT_THROW("[Run MoE FC] Scales expected but scale for first matmul is a null pointer"); - } else if (fc2_scales == nullptr) { - ORT_THROW("[Run MoE FC] Scales expected but scale for second matmul is a null pointer"); - } + if (fc3_scales == nullptr) { + ORT_THROW("[Run MoE FC] Scales expected but scale for third matmul is a null pointer"); + } } else { - if (fc1_scales != nullptr) { - ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); - } else if (fc2_scales != nullptr) { - ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); - } 
+ if (fc3_scales != nullptr) { + ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3"); + } } - - configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), - static_cast(inter_size), static_cast(num_experts), static_cast(k)); - topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, - source_rows_, num_rows, num_experts, k, normalize_routing_weights_, - use_sparse_mixer_, stream); - - const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); - sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, - source_rows_, permuted_rows_, k * num_rows, stream); - - initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, - expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, - k, stream); - - const int expanded_active_expert_rows = k * active_rows; - compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, - total_rows_before_expert_, stream); - - if (local_num_experts < num_experts) { - dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, - stream); + if (fc3_expert_weights == nullptr) { + ORT_THROW("[Run MoE FC] FC3 weights are null"); } + moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, fc3_expert_weights, fc3_scales, + fc3_expert_biases, fc3_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + inter_size, hidden_size, local_num_experts, stream); - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, - // expanded_active_expert_rows); - moe_gemm_runner_.moe_gemm_bias_act( - permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, - fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, inter_size, hidden_size, local_num_experts, fc1_activation_type, stream); - - if (has_fc3_) { - if (scales_required) { - if (fc3_scales == nullptr) { - ORT_THROW("[Run MoE FC] Scales expected but scale for third matmul is a null pointer"); - } - } else { - if (fc3_scales != nullptr) { - ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3"); - } - } - if (fc3_expert_weights == nullptr) { - ORT_THROW("[Run MoE FC] FC3 weights are null"); - } - moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, fc3_expert_weights, fc3_scales, - fc3_expert_biases, fc3_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, - inter_size, hidden_size, local_num_experts, stream); - - elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size, - static_cast(inter_size), static_cast(total_covered_rows_), stream); - } + elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size, + static_cast(inter_size), static_cast(total_covered_rows_), stream); + } - moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, fc2_expert_weights, fc2_scales, nullptr, - fc2_result + total_past_rows_ * hidden_size, - total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, - hidden_size, 
inter_size, local_num_experts, stream); + moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, fc2_expert_weights, fc2_scales, nullptr, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + hidden_size, inter_size, local_num_experts, stream); } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 template -void CutlassMoeFCRunner::run_moe_fc(const T *, const T *, const WeightType *, const T *, - const T *, ActivationType, const WeightType *, const T *, - const T *, const WeightType *, const T *, int, const int, - const int, int, int, int, int k, char *, T *, T *, int *, - int *, cudaStream_t) { - // MoE gemm only supports Volta+ architectures - ORT_THROW("[Run MoE FC] MoE gemm only supports Volta+ architectures"); +void CutlassMoeFCRunner::run_moe_fc(const T*, const T*, const WeightType*, const T*, + const T*, ActivationType, const WeightType*, const T*, + const T*, const WeightType*, const T*, int, const int, + const int, int, int, int, int k, char*, T*, T*, int*, + int*, cudaStream_t) { + // MoE gemm only supports Volta+ architectures + ORT_THROW("[Run MoE FC] MoE gemm only supports Volta+ architectures"); } #else template void CutlassMoeFCRunner::run_moe_fc( - const T *input_activations, const T *gating_output, const WeightType *fc1_expert_weights, const T *fc1_scales, - const T *fc1_expert_biases, ActivationType fc1_activation_type, const WeightType *fc3_expert_weights, - const T *fc3_scales, const T *fc3_expert_biases, const WeightType *fc2_expert_weights, const T *fc2_scales, + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, + const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, - int local_experts_start_index, int k, char *workspace_ptr, T *fc2_result, T *expert_scales, - int *expanded_source_row_to_expanded_dest_row, int *expert_for_source_row, cudaStream_t stream) { - run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, - fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size, - inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result, - nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, - stream); + int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { + run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, + fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size, + inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result, + nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, + stream); } #endif template -void CutlassMoeFCRunner::compute_total_rows_before_expert(const int *sorted_indices, +void CutlassMoeFCRunner::compute_total_rows_before_expert(const int* sorted_indices, const int total_indices, int num_experts, - int64_t 
*total_rows_before_expert, + int64_t* total_rows_before_expert, cudaStream_t stream) { - const int threads = std::min(1024, num_experts); - const int blocks = (num_experts + threads - 1) / threads; + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; - compute_total_rows_before_expert_kernel<<>>(sorted_indices, total_indices, num_experts, - total_rows_before_expert); + compute_total_rows_before_expert_kernel<<>>(sorted_indices, total_indices, num_experts, + total_rows_before_expert); } template -void CutlassMoeFCRunner::dispatch_activations(int64_t *total_rows_before_expert, int num_experts, +void CutlassMoeFCRunner::dispatch_activations(int64_t* total_rows_before_expert, int num_experts, int local_num_experts, int local_experts_start_index, cudaStream_t stream) { - total_rows_before_expert_host_.resize(num_experts); - cudaMemcpyAsync(total_rows_before_expert_host_.data(), total_rows_before_expert, num_experts * sizeof(int64_t), - cudaMemcpyDeviceToHost, stream); + total_rows_before_expert_host_.resize(num_experts); + cudaMemcpyAsync(total_rows_before_expert_host_.data(), total_rows_before_expert, num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, stream); - const int threads = std::min(1024, num_experts); - const int blocks = (num_experts + threads - 1) / threads; + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; - cudaEvent_t ©_event = cuda_event_.Get(); - cudaEventCreateWithFlags(©_event, cudaEventDisableTiming); - cudaEventRecord(copy_event, stream); + cudaEvent_t& copy_event = cuda_event_.Get(); + cudaEventCreateWithFlags(©_event, cudaEventDisableTiming); + cudaEventRecord(copy_event, stream); - dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, - local_num_experts, local_experts_start_index); + dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, + local_num_experts, local_experts_start_index); - get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_); + get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_); } template void CutlassMoeFCRunner::get_total_rows_info(int64_t experts_start_index, - int64_t local_num_experts, int64_t &total_past_rows, - int64_t &total_covered_rows) { - int64_t experts_end_index = experts_start_index + local_num_experts - 1; - total_past_rows = 0; + int64_t local_num_experts, int64_t& total_past_rows, + int64_t& total_covered_rows) { + int64_t experts_end_index = experts_start_index + local_num_experts - 1; + total_past_rows = 0; - cudaEventSynchronize(cuda_event_.Get()); + cudaEventSynchronize(cuda_event_.Get()); - if (experts_start_index > 0) { - total_past_rows = total_rows_before_expert_host_[experts_start_index - 1]; - } - total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows; + if (experts_start_index > 0) { + total_past_rows = total_rows_before_expert_host_[experts_start_index - 1]; + } + total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows; } // ========================== Permutation things ======================================= @@ -1003,149 +1010,149 @@ void CutlassMoeFCRunner::get_total_rows_info(int64_t expe // simply take the modulus of the expanded index. 
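To make the index arithmetic in the comment above concrete: each source row is duplicated k times, and copy k_idx of row r carries the expanded id k_idx * num_rows + r, so the original row falls out of a modulus and the copy index out of a division. The helpers below are a sketch with illustrative names, not functions from this file.

// Mapping between original rows and their k expanded copies.
inline int expanded_source_id(int source_row, int k_idx, int num_rows) {
  return k_idx * num_rows + source_row;  // matches source_rows[idx] = k_idx * num_rows + thread_row
}

inline int original_row_of(int expanded_id, int num_rows) {
  return expanded_id % num_rows;
}

inline int copy_index_of(int expanded_id, int num_rows) {
  return expanded_id / num_rows;
}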
template -__global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output, - const int *expanded_dest_row_to_expanded_source_row, - int *expanded_source_row_to_expanded_dest_row, int num_rows, +__global__ void initialize_moe_routing_kernel(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, int cols) { - // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need - // the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in - // MoE. 1 thread block will be responsible for all k summations. - const int expanded_dest_row = blockIdx.x; - const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - if (threadIdx.x == 0) { - expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; - } - - if (blockIdx.x < active_rows) { - // Duplicate and permute rows - const int source_row = expanded_source_row % num_rows; + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need + // the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in + // MoE. 1 thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; - const T *source_row_ptr = unpermuted_input + source_row * cols; - T *dest_row_ptr = permuted_output + expanded_dest_row * cols; - - for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { - dest_row_ptr[tid] = source_row_ptr[tid]; - } + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + dest_row_ptr[tid] = source_row_ptr[tid]; } + } } template -void initialize_moe_routing_kernelLauncher(const T *unpermuted_input, T *permuted_output, - const int *expanded_dest_row_to_expanded_source_row, - int *expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows * k; - const int threads = std::min(cols, 1024); - initialize_moe_routing_kernel - <<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, num_rows, k * active_rows, cols); + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + initialize_moe_routing_kernel + <<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, num_rows, k * active_rows, cols); } // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. 
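The per-element math performed by the finalize kernel below can be summarized as: output = optional skip terms plus the sum over the k selected experts of routing_scale * (expert_output + expert_bias). The scalar sketch here only restates that reduction with hypothetical names; it ignores the residual variants and the launch configuration.

// Combine the k expert outputs for one original row and one column c.
float finalize_one_element(const float* expanded_rows, const float* scales, const float* bias,
                           const int* source_to_dest, const int* expert_for_source_row,
                           int original_row, int num_rows, int cols, int c, int k) {
  float acc = 0.0f;  // RESIDUAL_NUM == 0: no skip connection
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    const int expanded_original = original_row + k_idx * num_rows;
    const int permuted_row = source_to_dest[expanded_original];
    const int expert = expert_for_source_row[original_row * k + k_idx];
    const float b = bias ? bias[expert * cols + c] : 0.0f;
    acc += scales[original_row * k + k_idx] * (expanded_rows[permuted_row * cols + c] + b);
  }
  return acc;
}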
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template -__global__ void finalize_moe_routing_kernel(const T *, T *, const T *, const T *, const T *, const T *, const int *, - const int *, int, int) { - // Does not support pre-Kepler architectures - ; +__global__ void finalize_moe_routing_kernel(const T*, T*, const T*, const T*, const T*, const T*, const int*, + const int*, int, int) { + // Does not support pre-Kepler architectures + ; } #else template -__global__ void finalize_moe_routing_kernel(const T *expanded_permuted_rows, T *reduced_unpermuted_output, - const T *skip_1, const T *skip_2, const T *bias, const T *scales, - const int *expanded_source_row_to_expanded_dest_row, - const int *expert_for_source_row, int cols, int k) { - const int original_row = blockIdx.x; - int num_rows = gridDim.x; - T *reduced_row_ptr = reduced_unpermuted_output + original_row * cols; - - const T *skip_1_row_ptr = nullptr; - if (RESIDUAL_NUM == 1) { - skip_1_row_ptr = skip_1 + original_row * cols; - } - const T *skip_2_row_ptr = nullptr; - if (RESIDUAL_NUM == 2) { - skip_2_row_ptr = skip_2 + original_row * cols; +__global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, + const T* skip_1, const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int cols, int k) { + const int original_row = blockIdx.x; + int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + + const T* skip_1_row_ptr = nullptr; + if (RESIDUAL_NUM == 1) { + skip_1_row_ptr = skip_1 + original_row * cols; + } + const T* skip_2_row_ptr = nullptr; + if (RESIDUAL_NUM == 2) { + skip_2_row_ptr = skip_2 + original_row * cols; + } + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output; + if (RESIDUAL_NUM == 0) { + thread_output = T(0); + } else if (RESIDUAL_NUM == 1) { + thread_output = skip_1_row_ptr[tid]; + } else if (RESIDUAL_NUM == 2) { + thread_output = skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; } + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; - for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { - T thread_output; - if (RESIDUAL_NUM == 0) { - thread_output = T(0); - } else if (RESIDUAL_NUM == 1) { - thread_output = skip_1_row_ptr[tid]; - } else if (RESIDUAL_NUM == 2) { - thread_output = skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; - } - for (int k_idx = 0; k_idx < k; ++k_idx) { - const int expanded_original_row = original_row + k_idx * num_rows; - const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; - - const int64_t k_offset = original_row * k + k_idx; - const T row_scale = scales[k_offset]; - const T *expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + const int64_t k_offset = original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; - const int expert_idx = expert_for_source_row[k_offset]; - const T *bias_ptr = bias ? bias + expert_idx * cols : nullptr; + const int expert_idx = expert_for_source_row[k_offset]; + const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; - thread_output = - thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + (bias_ptr ? 
bias_ptr[tid] : T(0))); - } - reduced_row_ptr[tid] = thread_output; + thread_output = + thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + (bias_ptr ? bias_ptr[tid] : T(0))); } + reduced_row_ptr[tid] = thread_output; + } } #endif template -void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *reduced_unpermuted_output, const T *bias, - const T *scales, const int *expanded_source_row_to_expanded_dest_row, - const int *expert_for_source_row, int num_rows, int cols, int k, +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, + const T* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows; - const int threads = std::min(cols, 1024); - finalize_moe_routing_kernel<<>>( - expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, bias, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); } template -void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *reduced_unpermuted_output, const T *skip, - const T *bias, const T *scales, - const int *expanded_source_row_to_expanded_dest_row, - const int *expert_for_source_row, int num_rows, int cols, int k, +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, + const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows; - const int threads = std::min(cols, 1024); - finalize_moe_routing_kernel - <<>>(expanded_permuted_rows, reduced_unpermuted_output, skip, nullptr, bias, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel + <<>>(expanded_permuted_rows, reduced_unpermuted_output, skip, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); } template -void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *reduced_unpermuted_output, const T *skip_1, - const T *skip_2, const T *bias, const T *scales, - const int *expanded_source_row_to_expanded_dest_row, - const int *expert_for_source_row, int num_rows, int cols, int k, +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows; - const int threads = std::min(cols, 1024); - if (skip_2 == nullptr) { - finalize_moe_routing_kernel<<>>( - expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); - } else { - finalize_moe_routing_kernel<<>>( - expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, - 
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); - } + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + if (skip_2 == nullptr) { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } else { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } } // ========================= TopK Softmax specializations =========================== -template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int, +template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, int, int, bool, bool, cudaStream_t); -template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int, +template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, bool, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== @@ -1155,23 +1162,23 @@ template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; // ===================== Specializations for init routing ========================= -template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int, +template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, cudaStream_t); -template void initialize_moe_routing_kernelLauncher(const half *, half *, const int *, int *, int, int, int, int, +template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, cudaStream_t); // ==================== Specializations for final routing =================================== -template void finalize_moe_routing_kernelLauncher(const float *, float *, const float *, const float *, const int *, - const int *, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const int *, - const int *, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const float *, float *, const float *, const float *, const float *, - const int *, const int *, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *, - const int *, const int *, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const float *, float *, const float *, const float *, const float *, - const float *, const int *, const int *, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *, - const half *, const int *, const int *, int, int, int, cudaStream_t); - -} // namespace ort_fastertransformer +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const 
float*, const float*, + const int*, const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, + const int*, const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, + const float*, const int*, const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, + const half*, const int*, const int*, int, int, int, cudaStream_t); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index 28dce3937dd23..27fbc48bdb88c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -36,7 +36,9 @@ template < bool Columnwise> struct BlkQuantTraits { // number of qbit elements to pack into whole bytes - static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 + : (qbits == 2) ? 4 + : 0; static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu index cea1834fa1b62..cbcd4ed2f54a0 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu @@ -2,15 +2,16 @@ // Licensed under the MIT License. 
#include +#include #include #include #include -#include -#include +#include #include +#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" -#include "dequantize_blockwise.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" using namespace onnxruntime::cuda; using namespace cub; @@ -43,6 +44,33 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, h *(reinterpret_cast(output)) = *(reinterpret_cast(results)); } +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, __nv_bfloat16 scale, __nv_bfloat16 zp, + __nv_bfloat16* output) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __nv_bfloat162 scale_bf162 = __bfloat162bfloat162(scale); + __nv_bfloat16 zp_adjust = __hneg(scale) * zp; + __nv_bfloat162 zp_adjust2 = __bfloat162bfloat162(zp_adjust); + + alignas(16) __nv_bfloat162 results[4]; + __nv_bfloat16 v0 = __uint2bfloat16_rn(values_quant & 0xF); + __nv_bfloat16 v1 = __uint2bfloat16_rn((values_quant >> 4) & 0xF); + results[0] = __halves2bfloat162(v0, v1) * scale_bf162 + zp_adjust2; + + __nv_bfloat16 v2 = __uint2bfloat16_rn((values_quant >> 8) & 0xF); + __nv_bfloat16 v3 = __uint2bfloat16_rn((values_quant >> 12) & 0xF); + results[1] = __halves2bfloat162(v2, v3) * scale_bf162 + zp_adjust2; + + __nv_bfloat16 v4 = __uint2bfloat16_rn((values_quant >> 16) & 0xF); + __nv_bfloat16 v5 = __uint2bfloat16_rn((values_quant >> 20) & 0xF); + results[2] = __halves2bfloat162(v4, v5) * scale_bf162 + zp_adjust2; + + __nv_bfloat16 v6 = __uint2bfloat16_rn((values_quant >> 24) & 0xF); + __nv_bfloat16 v7 = __uint2bfloat16_rn((values_quant >> 28) & 0xF); + results[3] = __halves2bfloat162(v6, v7) * scale_bf162 + zp_adjust2; + *(reinterpret_cast(output)) = *(reinterpret_cast(results)); +#endif +} + __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, float zp, float* output) { float zp_adjust = -scale * zp; output[0] = float(values_quant & 0xF) * scale + zp_adjust; @@ -94,6 +122,9 @@ __global__ void Dequantize4BitsKernelReOrder( if constexpr (std::is_same_v) { T zp_adjust = -scale * __short2half_rn(zp); output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else if constexpr (std::is_same_v) { + T zp_adjust = __hneg(scale) * __ushort2bfloat16_rn(zp); + output_i[i] = __uint2bfloat16_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; } else { T zp_adjust = -scale * T(zp); output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; @@ -214,6 +245,18 @@ template Status Dequantize4Bits( int n, int block_size, cudaStream_t stream); + +template Status Dequantize4Bits<__nv_bfloat16, uint8_t>( + __nv_bfloat16* output, + const uint8_t* quant_data, + const __nv_bfloat16* scales_data, + const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + template Status Dequantize4Bits( float* output, const uint8_t* quant_data, @@ -236,6 +279,17 @@ template Status Dequantize4Bits( int block_size, cudaStream_t stream); +template Status Dequantize4Bits<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* output, + const uint8_t* quant_data, + const __nv_bfloat16* scales_data, + const __nv_bfloat16* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + template < typename ElementT, int32_t block_size, @@ -301,10 +355,19 @@ __global__ void dequantizeThread4b(ElementT* dst, half2 v = {__ushort2half_rn(vi & 0xF), __ushort2half_rn((vi 
>> 4) & 0xF)}; half2 results = v * scale_half2 + zp_adjust2; + dst[j * rows + i] = results.x; + dst[j * rows + (i + 1)] = results.y; + } else if constexpr (std::is_same::value) { + __nv_bfloat162 scale_bf162 = {scale0, scale1}; + __nv_bfloat162 zp_adjust2 = {adjust0, adjust1}; + + __nv_bfloat162 v = {__ushort2bfloat16_rn(vi & 0xF), __ushort2bfloat16_rn((vi >> 4) & 0xF)}; + __nv_bfloat162 results = v * scale_bf162 + zp_adjust2; + dst[j * rows + i] = results.x; dst[j * rows + (i + 1)] = results.y; } else { - static_assert(std::is_same::value, "Only float and half are supported!"); + static_assert(std::is_same::value, "Only float, half and bfloat16 are supported!"); const uint8_t vi0 = vi & 0xf; const uint8_t vi1 = vi >> 4; dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0; @@ -429,6 +492,17 @@ template Status DequantizeBlockwise4b( int columns, cudaStream_t stream); +template Status DequantizeBlockwise4b<__nv_bfloat16>( + __nv_bfloat16* dst, + const uint8_t* src, + const __nv_bfloat16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu index e90ed85b22f02..8c33218c81767 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu @@ -2,16 +2,17 @@ // Licensed under the MIT License. #include +#include #include #include #include -#include -#include +#include #include +#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" -#include "dequantize_blockwise.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" using namespace onnxruntime::cuda; using namespace cub; @@ -42,6 +43,31 @@ __device__ __forceinline__ void DequantizeFourElements(uint32_t values_quant, ha *(reinterpret_cast(output)) = *(reinterpret_cast(results)); } +// Processes 4 elements (since each is 8 bits, 4 fit in uint32_t) for bfloat16 +__device__ __forceinline__ void DequantizeFourElements(uint32_t values_quant, __nv_bfloat16 scale, __nv_bfloat16 zp, __nv_bfloat16* output) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __nv_bfloat162 scale_bf162 = __bfloat162bfloat162(scale); + // Formula: val = (quant - zp) * scale = quant * scale - zp * scale + __nv_bfloat16 zp_adjust = -scale * zp; + __nv_bfloat162 zp_adjust2 = __bfloat162bfloat162(zp_adjust); + + alignas(16) __nv_bfloat162 results[2]; // Store 4 bfloat16 values + + // Extract 4 uint8_t values from uint32_t + __nv_bfloat16 v0 = __uint2bfloat16_rn(static_cast(values_quant & 0xFF)); + __nv_bfloat16 v1 = __uint2bfloat16_rn(static_cast((values_quant >> 8) & 0xFF)); + results[0] = __halves2bfloat162(v0, v1) * scale_bf162 + zp_adjust2; + + __nv_bfloat16 v2 = __uint2bfloat16_rn(static_cast((values_quant >> 16) & 0xFF)); + __nv_bfloat16 v3 = __uint2bfloat16_rn(static_cast((values_quant >> 24) & 0xFF)); + results[1] = __halves2bfloat162(v2, v3) * scale_bf162 + zp_adjust2; + + // Write 4 bfloat16 values + *(reinterpret_cast<__nv_bfloat162*>(output)) = results[0]; + *(reinterpret_cast<__nv_bfloat162*>(output + 2)) = results[1]; +#endif +} + // Processes 4 elements (since each is 8 bits, 4 fit in uint32_t) __device__ __forceinline__ void 
DequantizeFourElements(uint32_t values_quant, float scale, float zp, float* output) { // Assuming ZP is symmetric or already adjusted if needed. Standard formula: val = (quant - zp) * scale = quant * scale - zp * scale @@ -108,6 +134,10 @@ __global__ void Dequantize8BitsKernelReOrder( T zp_T = __ushort2half_rn(zp); T zp_adjust = -scale * zp_T; output_i[i] = __ushort2half_rn(q_val) * scale + zp_adjust; + } else if constexpr (std::is_same_v) { + T zp_T = __uint2bfloat16_rn(zp); + T zp_adjust = -scale * zp_T; + output_i[i] = __uint2bfloat16_rn(q_val) * scale + zp_adjust; } else { T zp_T = static_cast(zp); T zp_adjust = -scale * zp_T; @@ -145,14 +175,16 @@ __global__ void Dequantize8BitsKernel( if (zero_points) { zp = zero_points[block_id]; // Direct lookup, no packing } - // Convert uint8_t zp to T (float/half) + // Convert uint8_t zp to T (float/half/bfloat16) if constexpr (std::is_same_v) { zero_point_value = __uint2half_rn(zp); + } else if constexpr (std::is_same_v) { + zero_point_value = __uint2bfloat16_rn(zp); } else { zero_point_value = static_cast(zp); } - } else { // ZeroT is T (float or half) - // Default 0 for float/half zero point + } else { // ZeroT is T (float, half, or bfloat16) + // Default 0 for float/half/bfloat16 zero point zero_point_value = zero_points ? *(zero_points + block_id) : static_cast(0.0f); } @@ -240,6 +272,17 @@ template Status Dequantize8Bits( int block_size, cudaStream_t stream); +template Status Dequantize8Bits<__nv_bfloat16, uint8_t>( + __nv_bfloat16* output, + const uint8_t* quant_data, + const __nv_bfloat16* scales_data, + const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + template Status Dequantize8Bits( float* output, const uint8_t* quant_data, @@ -262,6 +305,17 @@ template Status Dequantize8Bits( int block_size, cudaStream_t stream); +template Status Dequantize8Bits<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* output, + const uint8_t* quant_data, + const __nv_bfloat16* scales_data, + const __nv_bfloat16* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + // Generic dequantization kernel for 8 bits template < typename ElementT, @@ -345,8 +399,12 @@ __global__ void dequantizeThread8b(ElementT* dst, const half zp_half = __uint2half_rn(zp_uint8); const half adjust = -scale * zp_half; dst[q_val_idx] = __uint2half_rn(q_val) * scale + adjust; + } else if constexpr (std::is_same::value) { + const __nv_bfloat16 zp_bf16 = __uint2bfloat16_rn(zp_uint8); + const __nv_bfloat16 adjust = -scale * zp_bf16; + dst[q_val_idx] = __uint2bfloat16_rn(q_val) * scale + adjust; } else { // Float - static_assert(std::is_same::value, "Only float and half are supported!"); + static_assert(std::is_same::value, "Only float, half and bfloat16 are supported!"); const float zp_float = static_cast(zp_uint8); const float adjust = -scale * zp_float; dst[q_val_idx] = static_cast(q_val) * scale + adjust; @@ -460,6 +518,17 @@ template Status DequantizeBlockwise8b( int columns, cudaStream_t stream); +template Status DequantizeBlockwise8b<__nv_bfloat16>( + __nv_bfloat16* dst, + const uint8_t* src, + const __nv_bfloat16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu 
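// Illustrative sketch, not part of the patch: the scalar formula the bfloat16 8-bit
// dequantization paths above implement. val = (quant - zero_point) * scale is computed as
// quant * scale + zp_adjust, with zp_adjust = -scale * zero_point hoisted once per
// quantization block, which is what the CUDA code does with __hfma2 / fused multiply-add.
// Function name is hypothetical; this is a host-side reference only.
#include <cstdint>
#include <vector>

std::vector<float> DequantizeBlock8bReference(const uint8_t* quant, float scale,
                                              uint8_t zero_point, int block_size) {
  std::vector<float> out(static_cast<size_t>(block_size));
  const float zp_adjust = -scale * static_cast<float>(zero_point);  // per-block constant
  for (int i = 0; i < block_size; ++i) {
    out[static_cast<size_t>(i)] = static_cast<float>(quant[i]) * scale + zp_adjust;
  }
  return out;
}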
index 2f74dd41f0759..f269c18fb5295 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -69,7 +69,7 @@ __global__ void kDequantizeBlockwise( __syncthreads(); LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); - #pragma unroll NUM_PER_TH +#pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max); vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max); @@ -140,18 +140,18 @@ Status DequantizeBnb4( int block_size, int numel, cudaStream_t stream) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 - CallkDequantizeBlockwise( - reinterpret_cast(quant_map), - reinterpret_cast(output), - quant_data, - reinterpret_cast(absmax), - block_size, - numel, - stream); - #else - CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); - #endif +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + CallkDequantizeBlockwise( + reinterpret_cast(quant_map), + reinterpret_cast(output), + quant_data, + reinterpret_cast(absmax), + block_size, + numel, + stream); +#else + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); +#endif return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh index a0d38c9853cd6..241fd6f99ad74 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -22,12 +22,12 @@ __device__ inline float ScalarMul(float a, float b) { template <> __device__ inline half ScalarMul(half a, half b) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - return a * b; - #else - // half multiplication not supported - return static_cast(static_cast(a) * static_cast(b)); - #endif +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return a * b; +#else + // half multiplication not supported + return static_cast(static_cast(a) * static_cast(b)); +#endif } template <> diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu index 5d634b8a929f1..9010d7dbdbcc1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu @@ -4,10 +4,11 @@ #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" -#include "matmul_nbits.cuh" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" using namespace onnxruntime::cuda; using namespace cub; @@ -167,6 +168,59 @@ __device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, sums[7] += v7 * a_vec_1.w; } +// Convert 8 4-bit integers stored in one uint32_t to 8 bfloat16s. 
+// The output order is [0,4], [1,5], [2,6], [3,7] +__device__ __forceinline__ void Convert8xInt4To8xBF16s(uint32_t value, __nv_bfloat162* bf16_2x4) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + const int i0 = (value >> 0) & 0xF; + const int i1 = (value >> 4) & 0xF; + const int i2 = (value >> 8) & 0xF; + const int i3 = (value >> 12) & 0xF; + const int i4 = (value >> 16) & 0xF; + const int i5 = (value >> 20) & 0xF; + const int i6 = (value >> 24) & 0xF; + const int i7 = (value >> 28) & 0xF; + + bf16_2x4[0] = __floats2bfloat162_rn(static_cast(i0), static_cast(i4)); + bf16_2x4[1] = __floats2bfloat162_rn(static_cast(i1), static_cast(i5)); + bf16_2x4[2] = __floats2bfloat162_rn(static_cast(i2), static_cast(i6)); + bf16_2x4[3] = __floats2bfloat162_rn(static_cast(i3), static_cast(i7)); +#endif +} + +__device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, nv_bfloat16 scale, uint8_t zp, const nv_bfloat16* a, nv_bfloat16* sums) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __nv_bfloat162 scale_bf162 = __bfloat162bfloat162(scale); + nv_bfloat16 zp_adjust = -scale * __uint2bfloat16_rn(zp); + __nv_bfloat162 zp_adjust2 = __bfloat162bfloat162(zp_adjust); + + const uint4 vec_a = *(reinterpret_cast(a)); + + constexpr uint32_t kLowHalf2 = 0x5410; + constexpr uint32_t kHighHalf2 = 0x7632; + + uint4 vec_permuted; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.x) : "r"(vec_a.x), "r"(vec_a.z), "r"(kLowHalf2)); + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.y) : "r"(vec_a.x), "r"(vec_a.z), "r"(kHighHalf2)); + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.z) : "r"(vec_a.y), "r"(vec_a.w), "r"(kLowHalf2)); + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.w) : "r"(vec_a.y), "r"(vec_a.w), "r"(kHighHalf2)); + + __nv_bfloat162 elements[4]; // [04, 15, 26, 37] + Convert8xInt4To8xBF16s(values_quant, elements); + + __nv_bfloat162 v0 = __hfma2(elements[0], scale_bf162, zp_adjust2); + __nv_bfloat162 v1 = __hfma2(elements[1], scale_bf162, zp_adjust2); + __nv_bfloat162 v2 = __hfma2(elements[2], scale_bf162, zp_adjust2); + __nv_bfloat162 v3 = __hfma2(elements[3], scale_bf162, zp_adjust2); + + __nv_bfloat162* sums_bf162 = reinterpret_cast<__nv_bfloat162*>(sums); + sums_bf162[0] = __hfma2(v0, *reinterpret_cast(&vec_permuted.x), sums_bf162[0]); + sums_bf162[1] = __hfma2(v1, *reinterpret_cast(&vec_permuted.y), sums_bf162[1]); + sums_bf162[2] = __hfma2(v2, *reinterpret_cast(&vec_permuted.z), sums_bf162[2]); + sums_bf162[3] = __hfma2(v3, *reinterpret_cast(&vec_permuted.w), sums_bf162[3]); +#endif +} + constexpr int kColsPerThreadBlock = 8; constexpr int kElementsPerThreadPerIteration = 8; constexpr int kWarpSize = GPU_WARP_SIZE; @@ -348,6 +402,19 @@ template bool TryMatMul4Bits( size_t shared_mem_per_block, cudaStream_t stream); +template bool TryMatMul4Bits( + nv_bfloat16* output, + const nv_bfloat16* a_data, + const uint8_t* b_data_quant, + const nv_bfloat16* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu index 30fbb486378a8..d8a212293e23d 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu @@ -4,10 +4,11 @@ #include #include #include +#include 
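// Illustrative sketch, not part of the patch: the pair ordering used by the bfloat16 int4
// path above. Convert8xInt4To8xBF16s emits the eight nibbles as the pairs (0,4), (1,5),
// (2,6), (3,7), and the prmt instructions permute the A vector into the same pair order so
// each __hfma2 multiplies matching elements. This host-side reference walks the pairs
// directly (no permutation is needed for scalars) and shows the dot product is unchanged.
// Function name is hypothetical.
#include <cstdint>

float DotInPairOrder(uint32_t packed_b, const float a[8], float scale, float zp_adjust) {
  float sum = 0.f;
  for (int p = 0; p < 4; ++p) {
    const int lo = (packed_b >> (4 * p)) & 0xF;        // nibble p
    const int hi = (packed_b >> (4 * (p + 4))) & 0xF;  // nibble p + 4, its pair partner
    sum += (static_cast<float>(lo) * scale + zp_adjust) * a[p];
    sum += (static_cast<float>(hi) * scale + zp_adjust) * a[p + 4];
  }
  return sum;
}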
#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" -#include "matmul_nbits.cuh" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" using namespace onnxruntime::cuda; using namespace cub; @@ -23,14 +24,14 @@ constexpr int kWarpSize = GPU_WARP_SIZE; // Typically 32 constexpr uint8_t kDefaultZeroPoint = 128; // Default zero point if not provided // --- Device Function: Accumulate 8 Elements (half precision) --- -// Dequantizes 8 uint8_t values and accumulates the result with 8 half values from A. -// sums += A * dequant(B_quant) +// Dequantizes 8 uint8_t values and accumulates the result with 8 half values from A into float sums. +// sums_f += A_h * dequant(B_quant) __device__ __forceinline__ void AccumulateEightElements8b( uint64_t values_quant, // 8 packed uint8_t values from B half scale, // Dequantization scale for this block uint8_t zp, // Dequantization zero point for this block const half* a, // Pointer to 8 half values from A - half* sums) { // Pointer to 8 partial sums (half) + float* sums_f) { // Pointer to 8 partial sums (float) #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530) // --- Dequantization Setup --- @@ -62,21 +63,33 @@ __device__ __forceinline__ void AccumulateEightElements8b( half2 b_vec3 = __hmul2(diff_67, scale_h2); // {b6, b7} // --- Load Input A (8 half values as 4 half2 vectors) --- - // Assumes 'a' is properly aligned for half2 reads. const half2* a_half2 = reinterpret_cast(a); half2 a_vec0 = a_half2[0]; // {a0, a1} half2 a_vec1 = a_half2[1]; // {a2, a3} half2 a_vec2 = a_half2[2]; // {a4, a5} half2 a_vec3 = a_half2[3]; // {a6, a7} - // --- Accumulate: sums += a * b_vec using half2 FMA --- - half2* sums_half2 = reinterpret_cast(sums); - sums_half2[0] = __hfma2(a_vec0, b_vec0, sums_half2[0]); // {s0+=a0*b0, s1+=a1*b1} - sums_half2[1] = __hfma2(a_vec1, b_vec1, sums_half2[1]); // {s2+=a2*b2, s3+=a3*b3} - sums_half2[2] = __hfma2(a_vec2, b_vec2, sums_half2[2]); // {s4+=a4*b4, s5+=a5*b5} - sums_half2[3] = __hfma2(a_vec3, b_vec3, sums_half2[3]); // {s6+=a6*b6, s7+=a7*b7} - -#else // older GPUs of compute capability < 5.3, which lacks native half support. + // Convert half2 inputs to float2 for fmaf operations on sums_f + float2 a_vec0_f = __half22float2(a_vec0); + float2 a_vec1_f = __half22float2(a_vec1); + float2 a_vec2_f = __half22float2(a_vec2); + float2 a_vec3_f = __half22float2(a_vec3); + + float2 b_vec0_f = __half22float2(b_vec0); + float2 b_vec1_f = __half22float2(b_vec1); + float2 b_vec2_f = __half22float2(b_vec2); + float2 b_vec3_f = __half22float2(b_vec3); + + sums_f[0] = fmaf(a_vec0_f.x, b_vec0_f.x, sums_f[0]); + sums_f[1] = fmaf(a_vec0_f.y, b_vec0_f.y, sums_f[1]); + sums_f[2] = fmaf(a_vec1_f.x, b_vec1_f.x, sums_f[2]); + sums_f[3] = fmaf(a_vec1_f.y, b_vec1_f.y, sums_f[3]); + sums_f[4] = fmaf(a_vec2_f.x, b_vec2_f.x, sums_f[4]); + sums_f[5] = fmaf(a_vec2_f.y, b_vec2_f.y, sums_f[5]); + sums_f[6] = fmaf(a_vec3_f.x, b_vec3_f.x, sums_f[6]); + sums_f[7] = fmaf(a_vec3_f.y, b_vec3_f.y, sums_f[7]); + +#else // older GPUs of compute capability < 5.3, which lacks native half support. float scale_f = __half2float(scale); float zp_f = static_cast(zp); @@ -90,23 +103,20 @@ __device__ __forceinline__ void AccumulateEightElements8b( #pragma unroll for (int i = 0; i < 8; ++i) { float a_f = __half2float(a[i]); - float product_f = a_f * b_dequant[i]; - // Convert back to half for partial sums. It is not ideal for performance. 
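// Illustrative sketch, not part of the patch: why the partial sums in AccumulateEightElements8b
// were switched from half to float above. Rounding the running sum to fp16 after every addition
// loses low-order bits once the sum grows; keeping a float accumulator and converting once at
// the end (as the new code does with fmaf into sums_f) avoids that. Host-side illustration
// using the __half conversions from cuda_fp16.h; function names are hypothetical.
#include <cuda_fp16.h>

float AccumulateRoundingToHalf(const float* vals, int n) {
  __half sum = __float2half(0.f);
  for (int i = 0; i < n; ++i) {
    sum = __float2half(__half2float(sum) + vals[i]);  // rounds to fp16 at every step
  }
  return __half2float(sum);
}

float AccumulateInFloat(const float* vals, int n) {
  float sum = 0.f;  // full fp32 accumulator, converted to the output type only at the end
  for (int i = 0; i < n; ++i) {
    sum += vals[i];
  }
  return sum;
}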
- half product_h = __float2half_rn(product_f); - sums[i] += product_h; + sums_f[i] = fmaf(a_f, b_dequant[i], sums_f[i]); } #endif } // --- Device Function: Accumulate 8 Elements (float precision) --- // Dequantizes 8 uint8_t values and accumulates the result with 8 float values from A. -// sums += A * dequant(B_quant) +// sums_f += A_f * dequant(B_quant) __device__ __forceinline__ void AccumulateEightElements8b( uint64_t values_quant, // 8 packed uint8_t values from B float scale, // Dequantization scale for this block uint8_t zp, // Dequantization zero point for this block const float* a, // Pointer to 8 float values from A - float* sums) { // Pointer to 8 partial sums (float) + float* sums_f) { // Pointer to 8 partial sums (float) // Load A using float4 for potentially better memory bandwidth float4 a_vec_0 = *(reinterpret_cast(a)); @@ -125,14 +135,45 @@ __device__ __forceinline__ void AccumulateEightElements8b( } // Accumulate using fmaf (fused multiply-add) - sums[0] = fmaf(v[0], a_vec_0.x, sums[0]); - sums[1] = fmaf(v[1], a_vec_0.y, sums[1]); - sums[2] = fmaf(v[2], a_vec_0.z, sums[2]); - sums[3] = fmaf(v[3], a_vec_0.w, sums[3]); - sums[4] = fmaf(v[4], a_vec_1.x, sums[4]); - sums[5] = fmaf(v[5], a_vec_1.y, sums[5]); - sums[6] = fmaf(v[6], a_vec_1.z, sums[6]); - sums[7] = fmaf(v[7], a_vec_1.w, sums[7]); + sums_f[0] = fmaf(v[0], a_vec_0.x, sums_f[0]); + sums_f[1] = fmaf(v[1], a_vec_0.y, sums_f[1]); + sums_f[2] = fmaf(v[2], a_vec_0.z, sums_f[2]); + sums_f[3] = fmaf(v[3], a_vec_0.w, sums_f[3]); + sums_f[4] = fmaf(v[4], a_vec_1.x, sums_f[4]); + sums_f[5] = fmaf(v[5], a_vec_1.y, sums_f[5]); + sums_f[6] = fmaf(v[6], a_vec_1.z, sums_f[6]); + sums_f[7] = fmaf(v[7], a_vec_1.w, sums_f[7]); +} + +// --- Device Function: Accumulate 8 Elements (bfloat16 precision) --- +// Dequantizes 8 uint8_t values and accumulates the result with 8 nv_bfloat16 values from A. 
+// sums_f += A_bf16 * dequant(B_quant) +__device__ __forceinline__ void AccumulateEightElements8b( + uint64_t values_quant, // 8 packed uint8_t values from B + nv_bfloat16 scale, // Dequantization scale for this block + uint8_t zp, // Dequantization zero point for this block + const nv_bfloat16* a, // Pointer to 8 nv_bfloat16 values from A + float* sums_f) { // Pointer to 8 partial sums (float) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + float scale_f = __bfloat162float(scale); + float zp_f = static_cast(zp); + + float zp_adjust = -scale_f * zp_f; + + float a_f[8]; + float b_dequant_f[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + a_f[i] = __bfloat162float(a[i]); + uint8_t q_val = (values_quant >> (i * 8)) & 0xFF; + b_dequant_f[i] = static_cast(q_val) * scale_f + zp_adjust; + } + +#pragma unroll + for (int i = 0; i < 8; ++i) { + sums_f[i] = fmaf(a_f[i], b_dequant_f[i], sums_f[i]); + } +#endif } // --- CUDA Kernel: MatMulFloat8bKernel (Optimized for m=1) --- @@ -140,7 +181,7 @@ __device__ __forceinline__ void AccumulateEightElements8b( // B(K, N) is quantized with 8 bits and block_size bs, stored as [N, K/bs, bs] // // Template Parameters: -// T: Data type for A and C (float or half) +// T: Data type for A and C (float, half, or nv_bfloat16) // block_size: Quantization block size for B // has_zero_point: Boolean indicating if zero points are provided // @@ -159,7 +200,6 @@ __global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloat8bK // --- Thread Indexing --- const int n_block_id = blockIdx.x; // Block column index [0, Ceil(N / kColsPerThreadBlock)) - // m_id is implicitly 0 since blockDim.y is 1 const int lane_id = threadIdx.x; // Thread index in warp (0..31) const int warp_id = threadIdx.y; // Warp index in block (0..kColsPerThreadBlock-1) @@ -227,10 +267,8 @@ __global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloat8bK } // --- Accumulation --- - // Initialize partial sums for this thread to zero - // Note that partial sum uses original data type. It is a trade-off between performance and accuracy. - // For example, K=3072, each accumulates k / k_per_iter = 3072 / 256 = 12 elements. - T sums[kElementsPerThreadPerIteration] = {static_cast(0.0f)}; + // Initialize partial sums for this thread to zero. Always accumulate in float for precision. + float sums[kElementsPerThreadPerIteration] = {0.0f}; constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration; // Elements processed per warp per iteration (e.g., 32*8 = 256) int k_id = 0; // Current position along the K dimension @@ -275,12 +313,12 @@ __global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloat8bK // --- Intra-Thread Reduction --- // Sum the kElementsPerThreadPerIteration partial sums within each thread. - // Here we use float to accumulate to avoid precision loss. + // Always accumulate in float to avoid precision loss. float total_sum_thread = 0.0f; #pragma unroll for (int i = 0; i < kElementsPerThreadPerIteration; ++i) { - total_sum_thread += static_cast(sums[i]); + total_sum_thread += sums[i]; } // --- Inter-Thread Reduction (Warp Level) --- @@ -289,13 +327,12 @@ __global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloat8bK // Allocate shared memory for CUB temporary storage (one per warp) __shared__ typename BlockReduce::TempStorage temp_storage[kColsPerThreadBlock]; - // Perform warp-level sum reduction. Use float in accumulation to avoid precision loss. + // Perform warp-level sum reduction. 
total_sum_thread = BlockReduce(temp_storage[warp_id]).Sum(total_sum_thread); // Lane 0 of each warp writes the final reduced sum to global memory if (lane_id == 0) { // Write result (cast back to T) - // Since m=1, output index is just n_id output[n_id] = static_cast(total_sum_thread); } } @@ -422,6 +459,20 @@ template bool TryMatMul8Bits( size_t shared_mem_per_block, cudaStream_t stream); +// Add template instantiation for nv_bfloat16 +template bool TryMatMul8Bits( + nv_bfloat16* output, + const nv_bfloat16* a_data, + const uint8_t* b_data_quant, + const nv_bfloat16* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu index 098e3618beddd..9cbadf972e4d1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -23,12 +23,12 @@ __device__ inline float ScalarMulFloatOut(float a, float b) { template <> __device__ inline float ScalarMulFloatOut(half a, half b) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - return static_cast(a * b); - #else - // half multiplication not supported - return static_cast(a) * static_cast(b); - #endif +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return static_cast(a * b); +#else + // half multiplication not supported + return static_cast(a) * static_cast(b); +#endif } template <> @@ -95,7 +95,7 @@ __global__ void kgemm_4bit_inference_naive( reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; } else { - #pragma unroll +#pragma unroll for (int j = 0; j < (num_values_8bit); j++) if ((inner_idx_halved) + j < (K / 2)) local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; @@ -103,12 +103,12 @@ __global__ void kgemm_4bit_inference_naive( local_B_4bit[j] = 0b01110111; } } else { - #pragma unroll +#pragma unroll for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; } for (int i = 0; i < 4; i++) { - #pragma unroll +#pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) { local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax); local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax); @@ -126,7 +126,7 @@ __global__ void kgemm_4bit_inference_naive( reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; } } else { - #pragma unroll +#pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { if (inner_idx + (i * num_values_4bit / 4) + k < K) local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; @@ -135,8 +135,8 @@ __global__ void kgemm_4bit_inference_naive( } } - // accumulate in float; small performance hit for Ampere, but lower error for outputs - #pragma unroll +// accumulate in float; small performance hit for Ampere, but lower error for outputs +#pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { local_C += ScalarMulFloatOut(local_A[k], local_B[k]); } @@ -243,22 +243,22 @@ bool TryMatMulBnb4( return false; } - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 - Callkgemm_4bit_inference_naive( - reinterpret_cast(quant_map), - reinterpret_cast(output), - reinterpret_cast(a_data), - b_data_quant, - reinterpret_cast(absmax), - m, - n, - k, - block_size, - stream); - #else - 
Callkgemm_4bit_inference_naive( - quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); - #endif +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + Callkgemm_4bit_inference_naive( + reinterpret_cast(quant_map), + reinterpret_cast(output), + reinterpret_cast(a_data), + b_data_quant, + reinterpret_cast(absmax), + m, + n, + k, + block_size, + stream); +#else + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); +#endif return true; } diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 71a84b877b8d1..8509892919639 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -8,6 +8,7 @@ #include "core/common/status.h" #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cpu/utils/dump_tensor.h" #include "contrib_ops/cuda/quantization/matmul_nbits.cuh" @@ -38,7 +39,10 @@ template void MatMulNBits::InitGemmProfiler(int sm) { gemmProfiler_ = s_profilerManager.createGemmPluginProfiler(/*inference*/ false); + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type; if constexpr (std::is_same_v) { + cuda_kernel_type = nbits_ == 8 ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; if (has_zero_points_) { if (nbits_ == 8) { weightOnlyGemmRunner_ = std::make_shared>(); @@ -53,6 +57,7 @@ void MatMulNBits::InitGemmProfiler(int sm) { } } } else if constexpr (std::is_same_v) { + cuda_kernel_type = nbits_ == 8 ? KernelType::BF16Int8Groupwise : KernelType::BF16Int4Groupwise; if (has_zero_points_) { if (nbits_ == 8) { weightOnlyGemmRunner_ = std::make_shared>(); @@ -68,8 +73,6 @@ void MatMulNBits::InitGemmProfiler(int sm) { } } - using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; - KernelType cuda_kernel_type = nbits_ == 8 ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; gemmProfiler_->setCudaKernelType(cuda_kernel_type, sm); gemmProfiler_->setQuant(nbits_, has_bias_, has_zero_points_); gemmProfiler_->setGroupSize(block_size_); @@ -83,37 +86,38 @@ void MatMulNBits::RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int // Number of 16-bit elements after casting int8/int4 to fp16. int n_16b = N_ / (nbits_ == 8 ? 
2 : 4); - gemmId_ = GemmIdCore(n_16b, K_, onnxruntime::llm::nvinfer::DataType::kHALF); + if constexpr (std::is_same_v) { + gemmId_ = GemmIdCore(n_16b, K_, onnxruntime::llm::nvinfer::DataType::kHALF); + } else if constexpr (std::is_same_v) { + gemmId_ = GemmIdCore(n_16b, K_, onnxruntime::llm::nvinfer::DataType::kBF16); + } GemmDims dims = {min_m, max_m, n_16b, K_}; gemmProfiler_->profileTactics(weightOnlyGemmRunner_, gemmId_.dtype, dims, gemmId_, hasWeightOnlyCudaKernel); } template -Status MatMulNBits::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc*/, +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; - return Status::OK(); -} - -template <> -Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, - PrePackedWeights* /*prepacked_weights*/) { - is_packed = false; - if (has_fpA_intB_gemm_) { - cudaStream_t stream = cudaStreamLegacy; // Use default stream for prepacking. - if (input_idx == MatMulNBits_Input_B) { - ORT_RETURN_IF_ERROR(PrePack_B(tensor, alloc, stream)); - is_packed = true; - } else if (input_idx == MatMulNBits_Input_Scale) { - ORT_RETURN_IF_ERROR(PrePack_Scale(tensor, alloc, stream)); - is_packed = true; - } else if (input_idx == MatMulNBits_Input_ZeroPoint) { - if (has_zero_points_) { - ORT_RETURN_IF_ERROR(PrePack_ZeroPoint(tensor, alloc, stream)); + if constexpr (std::is_same_v || std::is_same_v) { + if (has_fpA_intB_gemm_) { + cudaStream_t stream = cudaStreamLegacy; // Use default stream for prepacking. + if (input_idx == MatMulNBits_Input_B) { + ORT_RETURN_IF_ERROR(PrePack_B(tensor, alloc, stream)); + is_prepacked_weight_ = true; is_packed = true; + } else if (input_idx == MatMulNBits_Input_Scale) { + ORT_RETURN_IF_ERROR(PrePack_Scale(tensor, alloc, stream)); + is_prepacked_scale_ = true; + is_packed = true; + } else if (input_idx == MatMulNBits_Input_ZeroPoint) { + if (has_zero_points_) { + ORT_RETURN_IF_ERROR(PrePack_ZeroPoint(tensor, alloc, stream)); + is_prepacked_zero_point_ = true; + is_packed = true; + } } } } @@ -125,7 +129,7 @@ template Status MatMulNBits::PrePack_B([[maybe_unused]] const Tensor& tensor, [[maybe_unused]] AllocatorPtr alloc, [[maybe_unused]] cudaStream_t stream) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v) { size_t n = static_cast(N_); size_t k = static_cast(K_); @@ -175,7 +179,7 @@ template Status MatMulNBits::PrePack_Scale([[maybe_unused]] const Tensor& tensor, [[maybe_unused]] AllocatorPtr alloc, [[maybe_unused]] cudaStream_t stream) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v) { size_t n = static_cast(N_); size_t k = static_cast(K_); @@ -184,7 +188,7 @@ Status MatMulNBits::PrePack_Scale([[maybe_unused]] const Tensor& tensor, fpA_intB_scale_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. 
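// Illustrative sketch, not part of the patch: the bookkeeping this file introduces with
// is_prepacked_weight_ / is_prepacked_scale_ / is_prepacked_zero_point_. Each initializer
// input (B, scale, zero point) is prepacked independently and remembered with its own flag,
// and ComputeInternal later requires all of them before taking the fpA_intB_gemm path.
// Standalone illustration with hypothetical names.
struct PrepackFlags {
  bool weight = false;
  bool scale = false;
  bool zero_point = false;

  // Mirrors the ORT_ENFORCE added in ComputeInternal: fpA_intB_gemm needs every prepackable
  // input packed (the zero point only when the node actually has one).
  bool ReadyForFpAIntBGemm(bool has_zero_points) const {
    return weight && scale && (zero_point || !has_zero_points);
  }
};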
- typedef typename ToCudaType::MappedType CudaT; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); onnxruntime::llm::kernels::fpA_intB_gemv::launch_transpose_scale_kernel( @@ -201,14 +205,14 @@ template Status MatMulNBits::PrePack_ZeroPoint([[maybe_unused]] const Tensor& tensor, [[maybe_unused]] AllocatorPtr alloc, [[maybe_unused]] cudaStream_t stream) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v) { size_t n = static_cast(N_); size_t k = static_cast(K_); size_t k_blocks = (k + block_size_ - 1) / block_size_; size_t scale_bytes = n * k_blocks * sizeof(T); - typedef typename ToCudaType::MappedType CudaT; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; const CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); fpA_intB_zero_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. @@ -245,11 +249,17 @@ Status MatMulNBits::PrePack_ZeroPoint([[maybe_unused]] const Tensor& tensor, template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { - const bool is_prepacked = has_fpA_intB_gemm_; + if constexpr (std::is_same_v) { + if (sm_ < 80) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "BFloat16 MatMulNBits is not supported on cuda device with compute capability < 8.0"); + } + } + const Tensor* a = ctx->Input(0); - const Tensor* b = is_prepacked ? nullptr : ctx->Input(1); - const Tensor* scales = is_prepacked ? nullptr : ctx->Input(2); - const Tensor* zero_points = is_prepacked ? nullptr : ctx->Input(3); + const Tensor* b = is_prepacked_weight_ ? nullptr : ctx->Input(1); + const Tensor* scales = is_prepacked_scale_ ? nullptr : ctx->Input(2); + const Tensor* zero_points = is_prepacked_zero_point_ ? nullptr : ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); const Tensor* bias = ctx->Input(5); @@ -261,9 +271,9 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); const auto* a_data = a->Data(); - const uint8_t* blob_data = is_prepacked ? nullptr : b->Data(); - const auto* scales_data = is_prepacked ? nullptr : scales->Data(); - const auto* zero_points_data = (is_prepacked || zero_points == nullptr) ? nullptr : zero_points->DataRaw(); + const uint8_t* blob_data = is_prepacked_weight_ ? nullptr : b->Data(); + const auto* scales_data = is_prepacked_scale_ ? nullptr : scales->Data(); + const auto* zero_points_data = (is_prepacked_zero_point_ || zero_points == nullptr) ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); @@ -281,7 +291,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { cudaStream_t stream = static_cast(ctx->GetComputeStream()->GetHandle()); - typedef typename ToCudaType::MappedType CudaT; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; CudaT* out_data = reinterpret_cast(Y->MutableData()); int m = SafeInt(helper.M()); @@ -290,8 +300,13 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { DUMP_TENSOR_INIT(); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { if (has_fpA_intB_gemm_) { + // We expect weight/scale/zero_point(optional) inputs are initializers and have been prepacked. 
+ // User could disable it by setting ORT_FPA_INTB_GEMM=0 if those tensors cannot be prepacked (It is rare). + ORT_ENFORCE(is_prepacked_weight_ && is_prepacked_scale_ && (is_prepacked_zero_point_ || !has_zero_points_), + "To use fpA_intB_gemm, prepacking must be done on weight, scale and zero point."); + auto const& bestTactic = gemmProfiler_->getBestConfig(m, gemmId_); #if ORT_LLM_VERBOSE > 1 @@ -301,7 +316,12 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (bestTactic->enableCudaKernel) { using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; - KernelType cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + KernelType cuda_kernel_type; + if constexpr (std::is_same::value) { + cuda_kernel_type = nbits_ == 8 ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + } else if constexpr (std::is_same::value) { + cuda_kernel_type = nbits_ == 8 ? KernelType::BF16Int8Groupwise : KernelType::BF16Int4Groupwise; + } void const* pre_quant_scale_ptr = nullptr; bool apply_alpha_in_advance = false; @@ -447,8 +467,8 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { DUMP_TENSOR_D("DeQuantized", b_data, N_, K_padded); - const CudaT alpha = ToCudaType::FromFloat(1.f); - const CudaT zero = ToCudaType::FromFloat(0.f); + const CudaT alpha = onnxruntime::cuda::OrtToCudaType::FromFloat(1.f); + const CudaT zero = onnxruntime::cuda::OrtToCudaType::FromFloat(0.f); if (helper.OutputOffsets().size() == 1) { CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( @@ -497,6 +517,18 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MatMulNBits); +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + MatMulNBits); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index 3c8bd34c5e845..0d3558f91f03e 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -57,7 +57,7 @@ class MatMulNBits final : public CudaKernel { is_zero_points_scale_same_type_ = (zero_point_type == scale_type); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { int option = ParseEnvironmentVariableWithDefault(kFpAIntBGemmOption, 0); if ((option & (static_cast(nbits_) | kFpAIntBGemmOption_All)) != 0 && (block_size_ == 64 || block_size_ == 128) && @@ -68,7 +68,13 @@ class MatMulNBits final : public CudaKernel { sm_ >= 75) { if ((option & (kFpAIntBGemmOption_Gemv | kFpAIntBGemmOption_All)) != 0) { using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; - KernelType cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + KernelType cuda_kernel_type; + if constexpr (std::is_same::value) { + cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + } else if constexpr (std::is_same::value) { + cuda_kernel_type = (nbits_ == 8) ? 
KernelType::BF16Int8Groupwise : KernelType::BF16Int4Groupwise; + } + if (onnxruntime::llm::kernels::fpA_intB_gemv::is_supported(sm_, cuda_kernel_type)) { has_fpA_intB_gemv_ = true; } @@ -118,6 +124,10 @@ class MatMulNBits final : public CudaKernel { bool has_fpA_intB_gemv_{false}; bool has_fpA_intB_gemm_{false}; + bool is_prepacked_weight_{false}; + bool is_prepacked_scale_{false}; + bool is_prepacked_zero_point_{false}; + WeightOnlyGemmRunnerPtr weightOnlyGemmRunner_{nullptr}; mutable GemmProfilerPtr gemmProfiler_{nullptr}; GemmIdCore gemmId_{}; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_common.cuh b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_common.cuh index 2dbd3009ab8f9..614f3ba0675da 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_common.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_common.cuh @@ -61,11 +61,11 @@ union U1S2 { }; __device__ inline __half2 hmul2bk(const __half2 a, const __half2 b) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - return __hmul2(a, b); - #else - return __half2{(half)((float)a.x * (float)b.x), (half)((float)a.y * (float)b.y)}; - #endif +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return __hmul2(a, b); +#else + return __half2{(half)((float)a.x * (float)b.x), (half)((float)a.y * (float)b.y)}; +#endif } __device__ inline char2 QuantizeHalf2Char2(const __half2 xy, const __half2 inverse_scale2) { diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_layer_norm_impl.cu index 469efcbbab12c..a056887c030f7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_layer_norm_impl.cu @@ -73,9 +73,9 @@ __global__ void QOrderedLayerNormRowKernel(const int8_t* __restrict__ src, const template Status QOrderedLayerNorm(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, cublasLtOrder_t order, - const int8_t* src, const float src_scale, int8_t* dst, const float dst_scale, - const T* gamma, const T* beta, const float epsilon, - const unsigned batch, const unsigned rows, const unsigned cols) { + const int8_t* src, const float src_scale, int8_t* dst, const float dst_scale, + const T* gamma, const T* beta, const float epsilon, + const unsigned batch, const unsigned rows, const unsigned cols) { // The implementation only supports Row major tensor data ordering for now ORT_RETURN_IF(order != CUBLASLT_ORDER_ROW, "Order current not supported!"); @@ -87,18 +87,18 @@ Status QOrderedLayerNorm(cudaStream_t stream, const cudaDeviceProp& /*device_pro QOrderedLayerNormRowKernel<<>>( src, src_scale, dst, dst_scale, gamma, beta, epsilon, rows, cols); - return CUDA_CALL(cudaGetLastError()); + return CUDA_CALL(cudaGetLastError()); } template Status QOrderedLayerNorm(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, cublasLtOrder_t order, - const int8_t* src, const float src_scale, int8_t* dst, const float dst_scale, - const float* gamma, const float* beta, const float epsilon, - const unsigned batch, const unsigned rows, const unsigned cols); + const int8_t* src, const float src_scale, int8_t* dst, const float dst_scale, + const float* gamma, const float* beta, const float epsilon, + const unsigned batch, const unsigned rows, const unsigned cols); template Status QOrderedLayerNorm<__half>(cudaStream_t stream, const cudaDeviceProp& 
/*device_prop*/, cublasLtOrder_t order, - const int8_t* src, const float src_scale, int8_t* dst, const float dst_scale, - const __half* gamma, const __half* beta, const float epsilon, - const unsigned batch, const unsigned rows, const unsigned cols); + const int8_t* src, const float src_scale, int8_t* dst, const float dst_scale, + const __half* gamma, const __half* beta, const float epsilon, + const unsigned batch, const unsigned rows, const unsigned cols); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu b/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu index 1e6461a145144..9bfaee92b410a 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu +++ b/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu @@ -44,7 +44,7 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in } __syncthreads(); - // The starting index of current row in csr_col_indices + // The starting index of current row in csr_col_indices int offset = shared_row_indices[row]; // Output row indices. diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu index 7c3f2963207e6..e3f27238660eb 100644 --- a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu @@ -14,18 +14,18 @@ namespace contrib { namespace cuda { __global__ void DynamicTimeWarpingInitCost(float* cost_buffer, int8_t* trace_buffer, size_t cols_plus_1) { - int r = blockIdx.x; - cost_buffer += cols_plus_1 * r; + int r = blockIdx.x; + cost_buffer += cols_plus_1 * r; + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + cost_buffer[i] = FLT_MAX; + } + if (r == 0) { for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { - cost_buffer[i] = FLT_MAX; - } - if (r == 0) { - for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { - trace_buffer[i] = 2; - } + trace_buffer[i] = 2; } - if (threadIdx.x == 0) trace_buffer[cols_plus_1 * r] = 1; - if (threadIdx.x == 0 && r == 0) *cost_buffer = 0.0f; + } + if (threadIdx.x == 0) trace_buffer[cols_plus_1 * r] = 1; + if (threadIdx.x == 0 && r == 0) *cost_buffer = 0.0f; } __global__ void DynamicTimeWarpingKernel( @@ -36,54 +36,60 @@ __global__ void DynamicTimeWarpingKernel( float* cost_buffer, int8_t* trace_buffer, int32_t* result_buffer, - size_t* result_len_device -) { + size_t* result_len_device) { const int diag_max = static_cast(rows + cols); for (int d = 1; d <= diag_max; d++) { for (int c = threadIdx.x + 1; c <= cols; c += blockDim.x) { - int r = d - c; - if (r >= 1 && r <= rows) { - int cost_idx = ((r - 1) * (cols + 1) + (c - 1)); //[r - 1, c - 1] - const float c0 = cost_buffer[cost_idx]; - const float c1 = cost_buffer[cost_idx + 1]; // [r - 1, c] - const float c2 = cost_buffer[cost_idx + cols + 1]; // [r, c - 1] + int r = d - c; + if (r >= 1 && r <= rows) { + int cost_idx = ((r - 1) * (cols + 1) + (c - 1)); //[r - 1, c - 1] + const float c0 = cost_buffer[cost_idx]; + const float c1 = cost_buffer[cost_idx + 1]; // [r - 1, c] + const float c2 = cost_buffer[cost_idx + cols + 1]; // [r, c - 1] - float cost; - int8_t t; - if (c0 < c1 && c0 < c2) { - cost = c0; - t = 0; - } else if (c1 < c0 && c1 < c2) { - cost = c1; - t = 1; - } else { - cost = c2; - t = 2; - } - cost_idx += ((cols + 1) + 1); - cost_buffer[cost_idx] = cost + input[(r - 1) * cols + (c - 1)]; - trace_buffer[cost_idx] = t; + float cost; + int8_t t; + if (c0 < c1 && c0 < c2) { + 
cost = c0; + t = 0; + } else if (c1 < c0 && c1 < c2) { + cost = c1; + t = 1; + } else { + cost = c2; + t = 2; } + cost_idx += ((cols + 1) + 1); + cost_buffer[cost_idx] = cost + input[(r - 1) * cols + (c - 1)]; + trace_buffer[cost_idx] = t; + } } __syncthreads(); } - //back tracing, reverse append to result buffer + // back tracing, reverse append to result buffer if (threadIdx.x == 0) { int r = rows - 1; int c = cols - 1; - int pos = static_cast(max_index_len); // reverse put + int pos = static_cast(max_index_len); // reverse put while (r >= 0 && c >= 0) { - --pos; - result_buffer[pos] = r; - result_buffer[max_index_len + pos] = c; - const int trace_index = (r + 1) * (cols + 1) + (c + 1); - int8_t t = trace_buffer[trace_index]; - switch (t) { - case 0: r -= 1; c -= 1; break; - case 1: r -= 1; break; - default: c -= 1; break; - } + --pos; + result_buffer[pos] = r; + result_buffer[max_index_len + pos] = c; + const int trace_index = (r + 1) * (cols + 1) + (c + 1); + int8_t t = trace_buffer[trace_index]; + switch (t) { + case 0: + r -= 1; + c -= 1; + break; + case 1: + r -= 1; + break; + default: + c -= 1; + break; + } } *result_len_device = max_index_len - static_cast(pos); } @@ -92,10 +98,10 @@ __global__ void DynamicTimeWarpingKernel( size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len) { max_index_len = rows + cols + 1; size_t cost_buffer_size = ((rows + 1) * (cols + 1)); - return batch * max_index_len * 2 * sizeof(int32_t) + // two index arrays - sizeof(int64_t) + // final index array length - batch* cost_buffer_size * sizeof(float) + // cost buffer - batch* cost_buffer_size * sizeof(int8_t); // trace buffer + return batch * max_index_len * 2 * sizeof(int32_t) + // two index arrays + sizeof(int64_t) + // final index array length + batch * cost_buffer_size * sizeof(float) + // cost buffer + batch * cost_buffer_size * sizeof(int8_t); // trace buffer } Status LaunchDynamicTimeWarping( @@ -106,8 +112,7 @@ Status LaunchDynamicTimeWarping( size_t cols, const float* input, void* buffer, - size_t& result_len -) { + size_t& result_len) { ORT_ENFORCE(batch == 1); size_t max_index_len = rows + cols + 1; int32_t* result_buffer = (int32_t*)buffer; @@ -117,19 +122,19 @@ Status LaunchDynamicTimeWarping( dim3 block(device_prop.maxThreadsPerBlock); dim3 grid_init((unsigned)SafeInt(rows + 1), (unsigned)SafeInt(batch)); - DynamicTimeWarpingInitCost<<>>(cost_buffer, trace_buffer, cols+1); + DynamicTimeWarpingInitCost<<>>(cost_buffer, trace_buffer, cols + 1); ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); dim3 grid(1, (unsigned)SafeInt(batch)); DynamicTimeWarpingKernel<<>>( - rows, - cols, - max_index_len, - input, - cost_buffer, - trace_buffer, - result_buffer, - result_len_device_buf); + rows, + cols, + max_index_len, + input, + cost_buffer, + trace_buffer, + result_buffer, + result_len_device_buf); ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemcpyAsync(&result_len, result_len_device_buf, sizeof(size_t), cudaMemcpyDeviceToHost, stream))); diff --git a/onnxruntime/contrib_ops/cuda/tensor/image_scaler_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/image_scaler_impl.cu index a63cd4755c36b..bb5ea3e68a610 100644 --- a/onnxruntime/contrib_ops/cuda/tensor/image_scaler_impl.cu +++ b/onnxruntime/contrib_ops/cuda/tensor/image_scaler_impl.cu @@ -58,5 +58,5 @@ SPECIALIZED_IMPL(double) SPECIALIZED_IMPL(half) } // namespace cuda -} //namespace contrib +} // namespace contrib } // namespace onnxruntime diff --git 
a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu index a3c93ceb33c46..03f9edb3e6317 100644 --- a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu @@ -17,12 +17,11 @@ __global__ void UnfoldTensorKernel( const T* input, T* output, int64_t N, - int64_t unfold_size, // stride_tailing_dim_dst - int64_t tailing_dims_size, // stride_fold_dim_dst = tailing_dims_size * unfold_size, stride_append_dim_src = tailing_dims_size + int64_t unfold_size, // stride_tailing_dim_dst + int64_t tailing_dims_size, // stride_fold_dim_dst = tailing_dims_size * unfold_size, stride_append_dim_src = tailing_dims_size int64_t stride_leading_dst, int64_t stride_fold_dim_src, - int64_t stride_leading_src -) { + int64_t stride_leading_src) { int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; if (idx >= N) return; @@ -38,7 +37,6 @@ __global__ void UnfoldTensorKernel( output[idx] = input[idx_src]; } - Status LaunchUnfoldTensor( cudaStream_t stream, const cudaDeviceProp& device_prop, @@ -49,8 +47,7 @@ Status LaunchUnfoldTensor( int64_t unfold_dim_size, int64_t tailing_dims_size, int64_t unfold_size, - int64_t step_size -) { + int64_t step_size) { int64_t TPB = device_prop.maxThreadsPerBlock; int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1; int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size; @@ -65,32 +62,32 @@ Status LaunchUnfoldTensor( dim3 grid((unsigned)SafeInt(num_blocks)); switch (element_size) { case 1: - UnfoldTensorKernel<<>>( - (const int8_t*)input, (int8_t*)output, N, unfold_size, - tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); - break; + UnfoldTensorKernel<<>>( + (const int8_t*)input, (int8_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; case 2: - UnfoldTensorKernel<<>>( - (const int16_t*)input, (int16_t*)output, N, unfold_size, - tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); - break; + UnfoldTensorKernel<<>>( + (const int16_t*)input, (int16_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; case 4: - UnfoldTensorKernel<<>>( - (const int32_t*)input, (int32_t*)output, N, unfold_size, - tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); - break; + UnfoldTensorKernel<<>>( + (const int32_t*)input, (int32_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; case 8: - UnfoldTensorKernel<<>>( - (const int64_t*)input, (int64_t*)output, N, unfold_size, - tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); - break; + UnfoldTensorKernel<<>>( + (const int64_t*)input, (int64_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; case 16: - UnfoldTensorKernel<<>>( - (const float4*)input, (float4*)output, N, unfold_size, - tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); - break; + UnfoldTensorKernel<<>>( + (const float4*)input, (float4*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; default: - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported element_size"); + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, 
"Unsupported element_size"); } return CUDA_CALL(cudaGetLastError()); diff --git a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu index b2969194ff400..aa3014f326502 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu @@ -5,7 +5,6 @@ #include - #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/cu_inc/common.cuh" diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh index 66d21e12c2740..213940f132963 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh @@ -44,7 +44,7 @@ struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { int3 bias_strides; - const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides + const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides T* workspace_buffer; }; diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh index 3055d59cf17ae..8255e70d27e48 100644 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh @@ -111,9 +111,9 @@ template Status ElementwiseOp::operator()(const ElementwiseParams* params) { dim3 blocks(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)); ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, - params->bias, params->bias_length, - params->output); + params->input, params->input_length, + params->bias, params->bias_length, + params->output); return HIP_CALL(hipGetLastError()); } @@ -239,18 +239,18 @@ ElementwiseTunableOp::ElementwiseTunableOp() { } // namespace contrib } // namespace onnxruntime -#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ - namespace onnxruntime { \ - namespace contrib { \ - namespace rocm { \ - template Status LaunchElementwiseKernel( \ - RocmTuningContext * tuning_ctx, Stream* stream, \ - const T* input, int input_length, \ - const T* bias, int bias_length, \ - T* output); \ - namespace internal { \ - template class ElementwiseTunableOp; \ - } \ - } \ - } \ +#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ + namespace onnxruntime { \ + namespace contrib { \ + namespace rocm { \ + template Status LaunchElementwiseKernel( \ + RocmTuningContext * tuning_ctx, Stream* stream, \ + const T* input, int input_length, \ + const T* bias, int bias_length, \ + T* output); \ + namespace internal { \ + template class ElementwiseTunableOp; \ + } \ + } \ + } \ } diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index 992bba0fc5e6b..77f53f9eed027 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -30,8 +30,8 @@ namespace internal { #ifdef USE_COMPOSABLE_KERNEL -using onnxruntime::rocm::CKDataTypeAdaptor; using onnxruntime::rocm::CKBlasOpAdaptor; +using onnxruntime::rocm::CKDataTypeAdaptor; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu 
b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index e644b7e903138..85aef55908506 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -497,7 +497,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.head_size, // v head size GetCkFmhaDataTypeString(), !parameters.is_first_prompt, // true, // is_group_mode - true, // is_v_rowmajor ? dim is fastest : seq is fastest + true, // is_v_rowmajor ? dim is fastest : seq is fastest mask.type, bias_type, false, // has_lse diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh index 3f9183ef10828..2eeb7c3e8f279 100644 --- a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh @@ -155,7 +155,7 @@ __device__ inline void SimplifiedLayerNormVec( const VecV gamma_v = *reinterpret_cast(gamma + i); VecV output_v = *reinterpret_cast(output + idx); - #pragma unroll +#pragma unroll for (int k = 0; k < ILP; k++) { output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; } @@ -191,10 +191,9 @@ __device__ inline void LayerNormVec( const VecV gamma_v = *reinterpret_cast(gamma + i); VecV output_v = *reinterpret_cast(output + idx); - #pragma unroll +#pragma unroll for (int k = 0; k < ILP; k++) { - output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : - U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; + output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; } *(reinterpret_cast(output + idx)) = output_v; } @@ -228,10 +227,9 @@ __device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePa const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); VecV output_v; - #pragma unroll +#pragma unroll for (int i = 0; i < ILP; i++) { - output_v.val[i] = (beta != nullptr) ? U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : - U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; + output_v.val[i] = (beta != nullptr) ? 
U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; } *(reinterpret_cast(output + idx)) = output_v; } @@ -259,7 +257,7 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& threa const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); VecV output_v; - #pragma unroll +#pragma unroll for (int i = 0; i < ILP; i++) { output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 4cb371fdcf960..68f7d47282845 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -75,8 +75,8 @@ using device_normalization_f16_instances = DeviceNormalizationFwdImpl, DeviceNormalizationFwdImpl, DeviceNormalizationFwdImpl - // clang-format on - >; + // clang-format on + >; // Use this function to get implementation template tile_A: array; // 64 x 32 - RxC var tile_B: array; // 64 x 32 - RxC - - fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { - let a_global = tile_base + row; - if (a_global >= uniforms.M) { - return; - } - // Each call loads 8 columns, starting at col. - let col = c_idx * 8; - // 256 threads need to load 64 x 32. 4 threads per row or 8 col per thread. - for (var col_offset:u32 = 0; col_offset < 8; col_offset++) - { - tile_A[row * tile_k + col + col_offset] = component_type(input_a[a_global*uniforms.K + k_idx + col + col_offset]); - } - } )ADDNL_FN" << GenerateZeroPointReadingCode(nbits, has_zero_points, "component_type"); if (nbits == 4) { shader.AdditionalImplementation() << R"ADDNL_FN( @@ -176,17 +161,17 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t c var matC03: subgroup_matrix_result; for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { // Load Phase - loadSHMA(a_global_base, kidx, local_idx / 4, local_idx % 4); loadSHMB(b_global_base, kidx, local_idx / 4, local_idx % 4); workgroupBarrier(); for (var step: u32 = 0; step < tile_k; step += k_dim) { - // Load to local memory phase - let matrix_a_offset = subtile_id * subtile_rows * tile_k + step; + // Load A from global memory. + let matrix_a_offset = (a_global_base + subtile_id * subtile_rows) * uniforms.K + kidx + step; // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride - var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&input_a, matrix_a_offset, false, uniforms.K); + // Load B from shared local memory. // tile_B is stored as column major. 
// [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] var matrix_b_offset = step; diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index 652414b8d693a..cb7fe9af5d1ae 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -191,11 +191,7 @@ DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) { device.device_type = DLDeviceType::kDLCPU; break; case OrtDevice::GPU: -#ifdef USE_ROCM - device.device_type = DLDeviceType::kDLROCM; -#else device.device_type = DLDeviceType::kDLCUDA; -#endif break; default: ORT_THROW("Cannot pack tensors on this device."); diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 5140d3ffaefff..c2ff70b8e9808 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -1,7 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/narrow.h" +#include "core/common/parse_string.h" #include "core/common/safeint.h" +#include "core/common/status.h" #include "core/framework/allocator.h" #include "core/mlas/inc/mlas.h" #include "core/framework/utils.h" @@ -15,6 +18,51 @@ #include "core/framework/bfc_arena.h" +using Status = onnxruntime::common::Status; + +Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg) { + cfg = OrtArenaCfg{}; // reset to default values + + const auto from_string = [](const std::string& key, const std::string& str, auto& value) -> Status { + if (!onnxruntime::ParseStringWithClassicLocale(str, value).IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to parse value for ", key, " from ", str); + } + + return Status::OK(); + }; + + if (auto it = kvps.entries.find(ConfigKeyNames::ArenaExtendStrategy); it != kvps.entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.arena_extend_strategy)); + } + + if (auto it = kvps.entries.find(ConfigKeyNames::InitialChunkSizeBytes); it != kvps.entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.initial_chunk_size_bytes)); + } + + if (auto it = kvps.entries.find(ConfigKeyNames::MaxDeadBytesPerChunk); it != kvps.entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_dead_bytes_per_chunk)); + } + + if (auto it = kvps.entries.find(ConfigKeyNames::InitialGrowthChunkSizeBytes); it != kvps.entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.initial_growth_chunk_size_bytes)); + } + + if (auto it = kvps.entries.find(ConfigKeyNames::MaxPowerOfTwoExtendBytes); it != kvps.entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_power_of_two_extend_bytes)); + } + + if (auto it = kvps.entries.find(ConfigKeyNames::MaxMem); it != kvps.entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_mem)); + } + + if (!cfg.IsValid()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid arena configuration. 
Please check the values provided."); + } + + return Status::OK(); +} + namespace onnxruntime { // private helper for calculation so SafeInt usage doesn't bleed into the public allocator.h header diff --git a/onnxruntime/core/framework/data_transfer_manager.cc b/onnxruntime/core/framework/data_transfer_manager.cc index 29b6e6cc257dc..c0b3dab6e04f2 100644 --- a/onnxruntime/core/framework/data_transfer_manager.cc +++ b/onnxruntime/core/framework/data_transfer_manager.cc @@ -16,6 +16,20 @@ Status DataTransferManager::RegisterDataTransfer(std::unique_ptr return Status::OK(); } +Status DataTransferManager::UnregisterDataTransfer(IDataTransfer* data_transfer) { + auto iter = std::find_if(datatransfers_.begin(), datatransfers_.end(), + [&data_transfer](const std::unique_ptr& dt) { + return dt.get() == data_transfer; + }); + + if (iter != datatransfers_.end()) { + datatransfers_.erase(iter); + } + + // ignore if not found + return Status::OK(); +} + const IDataTransfer* DataTransferManager::GetDataTransfer(const OrtDevice& src_device, const OrtDevice& dst_device) const { for (auto& data_transfer : datatransfers_) { if (!data_transfer->CanCopy(src_device, dst_device)) { diff --git a/onnxruntime/core/framework/data_transfer_manager.h b/onnxruntime/core/framework/data_transfer_manager.h index d1dbf85bc7f41..d11da776f6268 100644 --- a/onnxruntime/core/framework/data_transfer_manager.h +++ b/onnxruntime/core/framework/data_transfer_manager.h @@ -17,6 +17,7 @@ class DataTransferManager { // static DataTransferManager& Instance(); common::Status RegisterDataTransfer(std::unique_ptr data_transfer); + common::Status UnregisterDataTransfer(IDataTransfer* data_transfer); const IDataTransfer* GetDataTransfer(const OrtDevice& src_device, const OrtDevice& dst_device) const; diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index 7bd825a9b0bb1..dc419b6621913 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -406,7 +406,7 @@ void DumpTensor( } else { std::cout << tensor_location << "\n"; -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_DML) const auto data_type = tensor.DataType(); // Dumping GPU only when cuda is enabled. 
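The new OrtArenaCfg::FromKeyValuePairs above resets the config to its defaults, overwrites each field only when its key is present (an if-with-initializer lookup plus a locale-independent string parse), and rejects the result if the combination is invalid. A minimal standalone sketch of that reset-then-overwrite pattern follows; ArenaCfg, ParseField, and the key strings are placeholders, not ORT definitions.

#include <charconv>
#include <string>
#include <unordered_map>

struct ArenaCfg {                 // stand-in for OrtArenaCfg; fields are illustrative
  int arena_extend_strategy = -1;
  long long max_mem = 0;
};

// Overwrite `field` from kvps[key] if the key exists; absent keys keep the default.
template <typename T>
bool ParseField(const std::unordered_map<std::string, std::string>& kvps,
                const std::string& key, T& field) {
  if (auto it = kvps.find(key); it != kvps.end()) {
    const char* first = it->second.data();
    const char* last = first + it->second.size();
    auto [ptr, ec] = std::from_chars(first, last, field);
    return ec == std::errc{} && ptr == last;  // reject values that fail to parse fully
  }
  return true;
}

// Mirrors the reset-then-overwrite flow of OrtArenaCfg::FromKeyValuePairs.
bool ArenaCfgFromKeyValuePairs(const std::unordered_map<std::string, std::string>& kvps,
                               ArenaCfg& cfg) {
  cfg = ArenaCfg{};  // reset to defaults before applying overrides
  return ParseField(kvps, "arena.extend_strategy", cfg.arena_extend_strategy) &&  // placeholder keys
         ParseField(kvps, "arena.max_mem", cfg.max_mem);
}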
if (tensor_location.device.Type() == OrtDevice::GPU) { diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 2081b8c3c9344..efc12ef8dd0e8 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -911,13 +911,16 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold; std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path; + bool force_embed_external_ini = false; if (external_ini_path.empty()) { - // Set the threshold to the max so all initializers are forced into the Onnx file + // if no external ini file specified, set force_embed_external_ini to true to avoid intermedia file creation + // and force all initializers embed into the Onnx file ini_size_threshold = SIZE_MAX; - external_ini_path = "./model_ext_ini.bin"; + force_embed_external_ini = true; } ModelSavingOptions model_saving_options{ini_size_threshold}; + model_saving_options.force_embed_external_ini = force_embed_external_ini; if (saving_to_buffer) { ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); diff --git a/onnxruntime/core/framework/plugin_data_transfer.cc b/onnxruntime/core/framework/plugin_data_transfer.cc new file mode 100644 index 0000000000000..f753f00206c5d --- /dev/null +++ b/onnxruntime/core/framework/plugin_data_transfer.cc @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/plugin_data_transfer.h" + +#include "core/framework/error_code_helper.h" + +namespace onnxruntime { +namespace plugin_ep { + +namespace { +static const std::function no_op_deleter = [](void*) {}; +static const MLDataType ml_tensor_type = DataTypeImpl::GetType(); +} // namespace + +Status DataTransfer::CopyTensors(const std::vector& src_dst_pairs) const { + // need to wrap the src/dst Tensor instances in OrtValue as the ORT API doesn't expose an OrtTensor. + // Adding an OrtTensor to the API would also require adding getters for type/shape/data. + // Those already exist for OrtValue so in order to minimize the API surface area we pay the price of a + // const_cast to convert the `const Tensor*` src to an OrtValue. + std::vector values; + values.resize(src_dst_pairs.size() * 2); + + for (size_t i = 0; i < src_dst_pairs.size(); ++i) { + const auto& pair = src_dst_pairs[i]; + + // we need to remove the const from the src to wrap it in an OrtValue. + // it's passed to the impl as a const OrtValue, and the deleter is a no-op so this should be safe. 
+ Tensor* src_tensor = const_cast(&(pair.src.get())); + values[i * 2].Init(static_cast(src_tensor), ml_tensor_type, no_op_deleter); + values[i * 2 + 1].Init(static_cast(&pair.dst.get()), ml_tensor_type, no_op_deleter); + } + + std::vector src_values; + std::vector dst_values; + std::vector streams; + src_values.reserve(src_dst_pairs.size()); + dst_values.reserve(src_dst_pairs.size()); + streams.reserve(src_dst_pairs.size()); + + for (size_t i = 0; i < src_dst_pairs.size(); ++i) { + src_values.push_back(&values[i * 2]); + dst_values.push_back(&values[i * 2 + 1]); + streams.push_back(nullptr); // static_cast(src_dst_pairs[i].src_stream)); + } + + auto* status = impl_.CopyTensors(&impl_, src_values.data(), dst_values.data(), streams.data(), + src_dst_pairs.size()); + + return ToStatusAndRelease(status); +} + +// optimized version for a single copy. see comments above in CopyTensors regarding the OrtValue usage and const_cast +Status DataTransfer::CopyTensorImpl(const Tensor& src_tensor, Tensor& dst_tensor, onnxruntime::Stream* /*stream*/) const { + OrtValue src, dst; + Tensor* src_tensor_ptr = const_cast(&src_tensor); + src.Init(static_cast(src_tensor_ptr), ml_tensor_type, no_op_deleter); + dst.Init(static_cast(&dst_tensor), ml_tensor_type, no_op_deleter); + const OrtValue* src_ptr = &src; + OrtValue* dst_ptr = &dst; + OrtSyncStream* stream_ptr = nullptr; // static_cast(stream); + auto* status = impl_.CopyTensors(&impl_, &src_ptr, &dst_ptr, &stream_ptr, 1); + + return ToStatusAndRelease(status); +} + +} // namespace plugin_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/plugin_data_transfer.h b/onnxruntime/core/framework/plugin_data_transfer.h new file mode 100644 index 0000000000000..e8ad29f4c0609 --- /dev/null +++ b/onnxruntime/core/framework/plugin_data_transfer.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/framework/data_transfer.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/ort_value.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/abi_devices.h" + +namespace onnxruntime { +namespace plugin_ep { + +/// +/// Class to implement IDataTransfer for plugin execution providers. +/// It uses the OrtDataTransferImpl from the plugin EP factory to implement the data transfer functionality. 
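The plugin_ep::DataTransfer implementation above wraps borrowed Tensor pointers in OrtValue objects with a no-op deleter so they can cross the C ABI as OrtValue* without transferring ownership (hence the const_cast on the source side). The snippet below shows the same non-owning-wrapper idiom with std::shared_ptr as a standalone analogue; it is not the ORT API.

#include <iostream>
#include <memory>

struct Tensor { float data[4]; };  // stand-in for onnxruntime::Tensor

// Hand a borrowed object to an API that expects shared-ownership semantics by
// pairing it with a deleter that does nothing; the caller keeps real ownership.
std::shared_ptr<Tensor> Borrow(Tensor* t) {
  return std::shared_ptr<Tensor>(t, [](Tensor*) { /* no-op deleter */ });
}

int main() {
  Tensor local{{1.f, 2.f, 3.f, 4.f}};
  std::shared_ptr<Tensor> view = Borrow(&local);  // no copy, no ownership transfer
  std::cout << view->data[2] << "\n";             // prints 3
  // destroying `view` runs the no-op deleter; `local` is untouched
}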
+/// +class DataTransfer : public IDataTransfer { + public: + DataTransfer(OrtDataTransferImpl& impl) + : impl_{impl} { + } + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override { + const OrtMemoryDevice* src_memory_device = static_cast(&src_device); + const OrtMemoryDevice* dst_memory_device = static_cast(&dst_device); + + return impl_.CanCopy(&impl_, src_memory_device, dst_memory_device); + } + + Status CopyTensor(const Tensor& src, Tensor& dst) const override { + return CopyTensorImpl(src, dst, nullptr); + } + + Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override { + return CopyTensorImpl(src, dst, &stream); + } + + Status CopyTensors(const std::vector& src_dst_pairs) const override; + + ~DataTransfer() override { + impl_.Release(&impl_); + } + + private: + Status CopyTensorImpl(const Tensor& src, Tensor& dst, onnxruntime::Stream* stream = nullptr) const; + + OrtDataTransferImpl& impl_; +}; +} // namespace plugin_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/run_options.cc b/onnxruntime/core/framework/run_options.cc index cb07cc22b1b2f..c55c7094a0585 100644 --- a/onnxruntime/core/framework/run_options.cc +++ b/onnxruntime/core/framework/run_options.cc @@ -63,6 +63,19 @@ ORT_API_STATUS_IMPL(OrtApis::AddRunConfigEntry, _Inout_ OrtRunOptions* options, return onnxruntime::ToOrtStatus(options->config_options.AddConfigEntry(config_key, config_value)); } +ORT_API_STATUS_IMPL(OrtApis::GetRunConfigEntry, _In_ const OrtRunOptions* options, + _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value_out) { + API_IMPL_BEGIN + const auto& config_options = options->config_options.GetConfigOptionsMap(); + if (auto it = config_options.find(config_key); it != config_options.end()) { + *config_value_out = it->second.c_str(); + } else { + *config_value_out = nullptr; + } + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, const _In_ OrtLoraAdapter* adapter) { API_IMPL_BEGIN diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index e92861fc4de63..46ec81abecce0 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -202,6 +202,13 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetImplicitInputs(std::unique_ptr& implicit_inputs) const = 0; + /// + /// Gets the node's attributes as an array of OrtOpAttr elements wrapped in an OrtArrayOfConstObjects. + /// + /// Output parameter set to the node's attributes. + /// A status indicating success or an error. + virtual onnxruntime::Status GetAttributes(std::unique_ptr& attrs) const = 0; + /// /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). 
/// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index bcef5fda9c0b4..d1133a445ebfa 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -91,6 +91,20 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_inputs, ep_node_inputs); ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_outputs, ep_node_outputs); + const auto& node_attrs = node.GetAttributes(); + std::unordered_map> ep_node_attributes_map; + std::vector ep_node_attributes; + + if (node_attrs.size() > 0) { + ep_node_attributes.reserve(node_attrs.size()); + + for (const auto& item : node_attrs) { + auto attr = std::make_unique(item.second); // Copy AttributeProto and owned by this EpNode object. + ep_node_attributes.push_back(reinterpret_cast(attr.get())); + ep_node_attributes_map.emplace(item.first, std::move(attr)); + } + } + std::vector ep_node_subgraphs; std::vector ep_node_implicit_inputs; @@ -115,6 +129,8 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ep_node->inputs_ = std::move(ep_node_inputs); ep_node->outputs_ = std::move(ep_node_outputs); + ep_node->attributes_map_ = std::move(ep_node_attributes_map); + ep_node->attributes_ = std::move(ep_node_attributes); ep_node->implicit_inputs_ = std::move(ep_node_implicit_inputs); ep_node->subgraphs_ = std::move(ep_node_subgraphs); @@ -169,6 +185,17 @@ Status EpNode::GetImplicitInputs(std::unique_ptr& result return Status::OK(); } +Status EpNode::GetAttributes(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtOpAttr); + result->storage.reserve(attributes_.size()); + + for (const OrtOpAttr* attr : attributes_) { + result->storage.push_back(attr); + } + + return Status::OK(); +} + Status EpNode::GetSubgraphs(std::unique_ptr& result) const { result = std::make_unique(ORT_TYPE_TAG_OrtGraph); result->storage.reserve(subgraphs_.size()); @@ -197,6 +224,15 @@ gsl::span EpNode::GetOutputsSpan() const { return outputs_; } +const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { + auto iter = attributes_map_.find(name); + if (iter == attributes_map_.end()) { + return nullptr; + } else { + return reinterpret_cast(iter->second.get()); + } +} + // // EpValueInfo // @@ -485,12 +521,18 @@ Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr< initializer_value_infos.push_back(value_info); - // Temporary: Copy onnx::TensorProto into OrtValue objects owned by this EpGraph. - // TODO: Remove this logic once a separate PR that updates onnxruntime::Graph to store initializers as - // OrtValue instances is merged. + // Initialize OrtValue for the initializer. auto initializer_value = std::make_unique(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), *tensor_proto, - initializer_allocator, *initializer_value)); + bool graph_has_ortvalue = graph_viewer.GetGraph().GetOrtValueInitializer(initializer_name, *initializer_value, + /*check_outer_scope*/ false); + + if (!graph_has_ortvalue) { + // onnxruntime::Graph does not have an OrtValue for this initializer, so create one from the TensorProto. + // This should only happen for small initializers that are needed for ONNX shape inferencing. 
+ ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), *tensor_proto, + initializer_allocator, *initializer_value)); + } + initializer_values.emplace(value_info->GetName(), std::move(initializer_value)); } @@ -538,24 +580,27 @@ Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr< EpValueInfo* outer_value_info = value_info_iter->second.get(); bool is_constant = false; + auto outer_initializer_value = std::make_unique(); const ONNX_NAMESPACE::TensorProto* outer_initializer = parent_graph->GetInitializer(implicit_name, - /*check_outer_scope*/ true, - is_constant); + *outer_initializer_value, + is_constant, + /*check_outer_scope*/ true); outer_value_info->SetFlag(EpValueInfo::kIsOuterScope); if (outer_initializer != nullptr) { outer_value_info->SetFlag(is_constant ? EpValueInfo::kIsConstantInitializer : EpValueInfo::kIsOptionalGraphInput); } - // Temporary: Copy onnx::TensorProto into OrtValue objects owned by this EpGraph. - // TODO: Remove this logic once a separate PR that updates onnxruntime::Graph to store initializers as - // OrtValue instances is merged. + // Add the OrtValue if this is an initializer. if (outer_initializer != nullptr) { - auto initializer_value = std::make_unique(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), parent_graph->ModelPath(), - *outer_initializer, initializer_allocator, - *initializer_value)); - outer_scope_initializer_values.emplace(outer_value_info->GetName(), std::move(initializer_value)); + if (!outer_initializer_value->IsAllocated()) { + // onnxruntime::Graph does not have an OrtValue for this initializer, so create one from the TensorProto. + // This should only happen for small initializers that are needed for ONNX shape inferencing. + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), parent_graph->ModelPath(), + *outer_initializer, initializer_allocator, + *outer_initializer_value)); + } + outer_scope_initializer_values.emplace(outer_value_info->GetName(), std::move(outer_initializer_value)); } } } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 358379a9b5854..ba1c4c1ee2b45 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -164,6 +164,9 @@ struct EpNode : public OrtNode { // Gets the node's implicit inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. Status GetImplicitInputs(std::unique_ptr& inputs) const override; + // Gets the node's attributes as OrtOpAttr instances wrapped in an OrtArrayOfConstObjects. + Status GetAttributes(std::unique_ptr& attrs) const override; + // Gets the subgraphs contained by this node. Status GetSubgraphs(std::unique_ptr& subgraphs) const override; @@ -186,6 +189,9 @@ struct EpNode : public OrtNode { // Helper that returns this node's outputs as a span of EpValueInfo pointers. gsl::span GetOutputsSpan() const; + // Helper that gets the node's attributes by name. + const OrtOpAttr* GetAttribute(const std::string& name) const; + private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. 
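The EpNode changes above keep two views of a node's attributes: a map of owned AttributeProto copies keyed by name, used by GetAttribute, and a flat vector of raw OrtOpAttr pointers returned wholesale by GetAttributes. A minimal standalone version of that owning-map-plus-pointer-view layout, with a placeholder Attr type, looks like this:

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Attr { std::string name; int value; };  // stand-in for AttributeProto / OrtOpAttr

class Node {
 public:
  void AddAttribute(const Attr& a) {
    auto owned = std::make_unique<Attr>(a);  // the node owns a copy
    attr_view_.push_back(owned.get());       // stable pointer exposed to callers
    attr_map_.emplace(owned->name, std::move(owned));
  }

  // by-name lookup, mirroring EpNode::GetAttribute
  const Attr* GetAttribute(const std::string& name) const {
    auto it = attr_map_.find(name);
    return it == attr_map_.end() ? nullptr : it->second.get();
  }

  // whole-set view, mirroring EpNode::GetAttributes
  const std::vector<const Attr*>& GetAttributes() const { return attr_view_; }

 private:
  std::unordered_map<std::string, std::unique_ptr<Attr>> attr_map_;  // owns the copies
  std::vector<const Attr*> attr_view_;                               // non-owning, insertion order
};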
@@ -196,6 +202,9 @@ struct EpNode : public OrtNode { InlinedVector inputs_; InlinedVector outputs_; + std::unordered_map> attributes_map_; + std::vector attributes_; + std::vector implicit_inputs_; std::vector subgraphs_; }; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index c9856b9964495..ca40bad2b4250 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3831,8 +3831,8 @@ const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& init return initializer; } -const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& initializer_name, bool check_outer_scope, - bool& is_constant) const { +const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& initializer_name, OrtValue& value, + bool& is_constant, bool check_outer_scope) const { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; if (GetInitializedTensor(initializer_name, initializer)) { if (CanOverrideInitializer()) { @@ -3844,10 +3844,13 @@ const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& init } else { is_constant = true; } + + auto it = ortvalue_initializers_.find(initializer_name); + value = (it != ortvalue_initializers_.end()) ? it->second : OrtValue(); } else if (check_outer_scope && IsSubgraph()) { // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. if (IsOuterScopeValue(initializer_name)) { - initializer = parent_graph_->GetInitializer(initializer_name, check_outer_scope, is_constant); + initializer = parent_graph_->GetInitializer(initializer_name, value, is_constant, check_outer_scope); } } @@ -4342,7 +4345,8 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( std::vector raw_data; ORT_RETURN_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); size_t tensor_bytes_size = raw_data.size(); - if (tensor_bytes_size < model_saving_options.initializer_size_threshold) { + if (model_saving_options.force_embed_external_ini || + tensor_bytes_size < model_saving_options.initializer_size_threshold) { *output_proto = initializer; // Data with size above the threshold is written into the new external initializer file // Data with size below the threshold should be kept inside the new model file @@ -4438,25 +4442,31 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers( const std::filesystem::path modified_external_file_path = model_file_path.parent_path() / external_file_path; const auto& model_path = ModelPath(); - // Create the external file. - std::ofstream external_stream(modified_external_file_path, std::ofstream::out | std::ofstream::binary); - auto const external_empty_pos = external_stream.tellp(); - ORT_ENFORCE(external_stream.is_open(), "Failed to open for writing:", modified_external_file_path); + std::ofstream external_stream; + std::streampos external_empty_pos; int64_t external_offset = 0; + if (!model_saving_options.force_embed_external_ini) { + // Create the external file. 
+ external_stream.open(modified_external_file_path, std::ofstream::out | std::ofstream::binary); + external_empty_pos = external_stream.tellp(); + ORT_ENFORCE(external_stream.is_open(), "Failed to open for writing:", modified_external_file_path); + } ORT_THROW_IF_ERROR(AddExternalInitializersToGraphProtoImpl(model_path, external_file_path, modified_external_file_path, model_saving_options, result, external_stream, external_offset)); - if (!external_stream.flush()) { - ORT_THROW("Failed to flush file with external initializers: ", modified_external_file_path); - } + if (!model_saving_options.force_embed_external_ini) { + if (!external_stream.flush()) { + ORT_THROW("Failed to flush file with external initializers: ", modified_external_file_path); + } - // Delete if the external data file is empty - if (external_empty_pos == external_stream.tellp()) { - external_stream.close(); - std::remove(modified_external_file_path.string().c_str()); + // Delete if the external data file is empty + if (external_empty_pos == external_stream.tellp()) { + external_stream.close(); + std::remove(modified_external_file_path.string().c_str()); + } } return result; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 5860269193b94..1fb6018164977 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -114,6 +114,11 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the implicit inputs for OrtNode"); } + Status GetAttributes(std::unique_ptr& /*attrs*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); + } + Status GetSubgraphs(std::unique_ptr& /*subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp index 72f0f5d8a4dd4..56c67aa4feca8 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp @@ -89,7 +89,7 @@ MlasGemmQuantCopyPackA( { MLAS_UNREFERENCED_PARAMETER(AIsSigned); const v128_t ZeroVector = wasm_i64x2_const(0, 0); - const v128_t OnesWordBroadcast = wasm_i16x8_splat(1); + const v128_t OnesByteBroadcast = wasm_i8x16_splat(1); uint8_t PaddedMatrixAData[8] = { 0 }; // @@ -109,19 +109,23 @@ MlasGemmQuantCopyPackA( // but CountK is aligned up to a multiple of 4 to maintain 32-bit // alignment. All extra bytes are zero-padded. // - // Zero extend the source bytes to 16-bits and accumulate - // into an intermediate per-row - // accumulator. CountK cannot be greater than 128 to avoid overflowing - // these signed 16-bit accumulators. - // + // Accumulate into an intermediate per-row accumulator. 
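The qgemm_kernel_wasmrelaxedsimd.cpp change replaces the 8-to-16-bit widening adds with wasm_i32x4_relaxed_dot_i8x16_i7x16_add against an all-ones byte vector, so each 32-bit accumulator lane collects the sum of four adjacent source bytes and the removed 16-bit-accumulator overflow caveat (CountK no greater than 128) no longer applies. A scalar reference model of that dot-with-ones accumulation, assuming the data bytes stay within the 7-bit operand range the relaxed instruction requires for a deterministic result:

#include <array>
#include <cstdint>
#include <cstdio>

// Scalar model of i32x4_relaxed_dot_i8x16_i7x16_add(ones, bytes, acc): with an
// all-ones first operand, every 32-bit lane accumulates the sum of four
// adjacent bytes of `bytes`.
std::array<int32_t, 4> DotOnesAdd(const std::array<int8_t, 16>& bytes,
                                  std::array<int32_t, 4> acc) {
  for (int lane = 0; lane < 4; ++lane)
    for (int j = 0; j < 4; ++j)
      acc[lane] += 1 * bytes[lane * 4 + j];
  return acc;
}

int main() {
  std::array<int8_t, 16> a{};
  for (int i = 0; i < 16; ++i) a[i] = static_cast<int8_t>(i);
  auto acc = DotOnesAdd(a, {0, 0, 0, 0});
  std::printf("%d %d %d %d\n", acc[0], acc[1], acc[2], acc[3]);  // 6 22 38 54
}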
- while (k >= 8) { + while (k >= 16) { - v128_t Bytes = wasm_v128_load64_zero(&a[0]); - v128_t Words = wasm_i8x16_unpacklo_relaxed(Bytes, ZeroVector); + v128_t Bytes = wasm_v128_load(&a[0]); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); - ReductionVector = wasm_i16x8_add(ReductionVector, Words); + wasm_v128_store(&D[0], Bytes); + a += 16; + D += 16; + k -= 16; + } + + if (k >= 8) { + v128_t Bytes = wasm_v128_load64_zero(&a[0]); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); wasm_v128_store64_lane(&D[0], Bytes, 0); a += 8; @@ -145,9 +149,7 @@ MlasGemmQuantCopyPackA( } while (padded < padded_end); v128_t Bytes = wasm_v128_load64_zero(PaddedMatrixAData); - v128_t Words = wasm_i8x16_unpacklo_relaxed(Bytes, ZeroVector); - - ReductionVector = wasm_i16x8_add(ReductionVector, Words); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); // // Copy quads of 8-bit values from the vector to the packed @@ -165,7 +167,6 @@ MlasGemmQuantCopyPackA( // Reduce the partial accumulators. // - ReductionVector = wasm_i32x4_dot_i16x8(ReductionVector, OnesWordBroadcast); ReductionVector = wasm_i32x4_add(ReductionVector, wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 2, 3, 2, 3)); ReductionVector = wasm_i32x4_add(ReductionVector, diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index dd8c5f8a96c17..16e3f4f1ec9ce 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -4,6 +4,9 @@ #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/rule_based_graph_transformer.h" +#include +#include + using namespace onnxruntime; using namespace ::onnxruntime::common; @@ -60,7 +63,8 @@ void GraphTransformerManager::ClearGraphModified(void) { common::Status GraphTransformerManager::Register(std::unique_ptr transformer, TransformerLevel level) { const auto& name = transformer->Name(); - if (transformers_info_.find(name) != transformers_info_.end()) { + const auto& registered = level_to_transformer_map_[level]; + if (std::find(registered.begin(), registered.end(), transformer) != registered.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer is already registered " + name); } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 062cbce6387e6..4edf804e48aaa 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -215,6 +215,7 @@ InlinedVector> GenerateTransformers( const InlinedHashSet cpu_acl_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider}; #endif + const InlinedHashSet no_limit_empty_ep_list = {}; const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); @@ -243,7 +244,6 @@ InlinedVector> GenerateTransformers( for (const auto& p : session_options.initializers_to_share_map) { excluded_initializers.insert(p.first); } - const InlinedHashSet no_limit_empty_ep_list = {}; transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers)); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq, @@ -363,14 +363,13 @@ InlinedVector> 
GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cuda_eps)); - + // Run MatMulAddFusion again after *AttentionFusion transforms with `preserve_attention_pattern = false`, + // to cleanup the remaining MatMul-Add that were part of the attention pattern but not detected or fused. + transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, false)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index c2cdf360ad986..f611c992e0f57 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -30,93 +30,29 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); } -#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS -// TODO(mtavenrath) generate list from registered kernels using nhwc domain -const std::unordered_set& GetCUDALayoutSensitiveOps() { - static std::unordered_set cuda_nhwc_ops = []() { - return std::unordered_set{ - "BatchNormalization", - "Conv", - "ConvTranspose", - "GlobalMaxPool", - "MaxPool", - "GlobalAveragePool", - "AveragePool", - "GridSample", - "DepthToSpace", - "SpaceToDepth", - "LRN"}; - }(); - return cuda_nhwc_ops; -} -#endif - /// /// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the /// default set of layout sensitive operators if required. -/// -/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we -/// don't hardcode it here. /// +/// The EP instance. /// Node to check /// true if the node should have its layout converted to NHWC. -bool ConvertNodeLayout(const api::NodeRef& node) { +bool ShouldConvertNodeLayoutToNhwc(const IExecutionProvider& execution_provider, const api::NodeRef& node) { // skip if op is not an ONNX or contrib op - auto domain = node.Domain(); + const auto domain = node.Domain(); if (domain != kOnnxDomain && domain != kMSDomain) { return false; } - const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); - - // handle special cases -#if defined(USE_JSEP) - // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed - if (node.GetExecutionProviderType() == kJsExecutionProvider) { - if (node.OpType() == "Resize") { - // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain - // with the original input layout. 
- return false; - } + const auto op_type = node.OpType(); + if (auto should_convert_from_ep = execution_provider.ShouldConvertDataLayoutForOp(domain, op_type, DataLayout::NHWC); + should_convert_from_ep.has_value()) { + return *should_convert_from_ep; } -#endif -// NHWC for Resize operator is not implemented on kWebGpuExecutionProvider -#if defined(USE_WEBGPU) - if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) { - if (node.OpType() == "Resize") { - return false; - } - } -#endif - -// TODO: We don't need to check USE_CUDA || USE_CUDA_PROVIDER_INTERFACE in this function because we're already -// checking if the node is assigned to the desired EP (e.g., CUDA EP). We should only need to check -// ENABLE_CUDA_NHWC_OPS. -#if (defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)) && ENABLE_CUDA_NHWC_OPS - if (node.GetExecutionProviderType() == kCudaExecutionProvider) { - if (layout_sensitive_ops.count(node.OpType())) { - const auto& cuda_nhwc_ops = GetCUDALayoutSensitiveOps(); - if (!cuda_nhwc_ops.count(node.OpType())) { - return false; - } - } - } -#endif - -// TODO: We don't really need EP pre-processor macros in this function because we're already checking if the -// node is assigned to the desired EP (e.g., QNN EP). There's nothing about this code that absolutely requires -// conditional compilation. -#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) - if (node.GetExecutionProviderType() == kQnnExecutionProvider) { - if (node.OpType() == "Upsample") { - // Upsample is translated to QNN's Resize, which requires the NHWC layout for processing. - return true; - } - } -#endif - - return layout_sensitive_ops.count(node.OpType()) != 0; + const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + const auto op_identifier = MakeORTLayoutSensitiveOpId(domain, op_type); + return layout_sensitive_ops.find(op_identifier) != layout_sensitive_ops.end(); } } // namespace @@ -126,25 +62,37 @@ bool ConvertNodeLayout(const api::NodeRef& node) { // Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove // as many of the layout transposes as possible. const std::unordered_set& GetORTLayoutSensitiveOps() { - static std::unordered_set ort_layout_sensitive_ops = []() { - const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); + static const std::unordered_set ort_layout_sensitive_ops = []() { + const auto& layout_sensitive_onnx_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); + + // Define a static local string array so we can refer to the elements with string_views. + static const std::string layout_sensitive_contrib_ops[]{ + MakeORTLayoutSensitiveOpId(kMSDomain, "FusedConv"), + MakeORTLayoutSensitiveOpId(kMSDomain, "GridSample"), + MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearAveragePool"), + MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearGlobalAveragePool"), + }; + std::unordered_set ort_specific_ops = { - "FusedConv", - "QLinearAveragePool", - "QLinearGlobalAveragePool", // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default // as EPs tend to only support one layout. 
"Resize", }; - ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); + ort_specific_ops.insert(std::begin(layout_sensitive_onnx_ops), std::end(layout_sensitive_onnx_ops)); + ort_specific_ops.insert(std::begin(layout_sensitive_contrib_ops), std::end(layout_sensitive_contrib_ops)); return ort_specific_ops; }(); return ort_layout_sensitive_ops; } +// "op_type" if from ONNX domain, "domain:op_type" otherwise. +std::string MakeORTLayoutSensitiveOpId(std::string_view domain, std::string_view op_type) { + return (domain == kOnnxDomain) ? std::string(op_type) : MakeString(domain, ":", op_type); +} + Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider, AllocatorPtr cpu_allocator, const DebugGraphFn& debug_graph_fn) { @@ -159,7 +107,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid continue; } - if (ConvertNodeLayout(*node)) { + if (ShouldConvertNodeLayoutToNhwc(execution_provider, *node)) { // domain kMSInternalNHWCDomain uses OpType "Conv" for both Conv and FusedConv. // So, change the OpType to "Conv" for FusedConv. std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType(); diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.h b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.h index 23971975ecc3e..eaf0d3cc221f2 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.h +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.h @@ -68,10 +68,19 @@ bool IsSupportedOpset(const Graph& graph); /// Gets a list of layout sensitive ops for ORT. This list contains ONNX standard defined /// layout sensitive ops + contrib ops + ops which are not layout sensitive but are treated as /// layout sensitive by ORT EPs (example Resize). +/// +/// Note: The format of the returned op identifiers is "" for ops in the ONNX domain and +/// ":" for ops in other domains. `MakeORTLayoutSensitiveOpId()` can be used to +/// create an op identifier with this format. /// -/// unordered set of op_types which are layout sensitive +/// set of op identifiers which are layout sensitive const std::unordered_set& GetORTLayoutSensitiveOps(); +/// +/// Creates an op identifier compatible with `GetORTLayoutSensitiveOps()`. +/// +std::string MakeORTLayoutSensitiveOpId(std::string_view domain, std::string_view op_type); + /// /// Inserts transposes around op inputs/outputs. Alternatively transposes initializers or uses existing Transpose /// nodes if possible. Populates shape information on affected node inputs/outputs to reflect the change. diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index a6c422e59aeef..761fe1854274e 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -1,11 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/common/inlined_containers.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/graph_transformer_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/matmul_add_fusion.h" -#include "core/graph/graph_utils.h" -#include "core/framework/tensorprotoutils.h" -#include + +#include +#include +#include +#include using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; @@ -128,7 +134,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, int64_t m = 0, k = 0, n = 0; if (need_reshape) { // Only check and skip Attention pattern here because normally input to Attention is 4D. - if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) { + if (preserve_attention_pattern_ && attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) { continue; } diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.h b/onnxruntime/core/optimizer/matmul_add_fusion.h index 007fb18f00a1c..dc835197a54c5 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.h +++ b/onnxruntime/core/optimizer/matmul_add_fusion.h @@ -9,10 +9,15 @@ namespace onnxruntime { class MatMulAddFusion : public GraphTransformer { public: - MatMulAddFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("MatMulAddFusion", compatible_execution_providers) {} + MatMulAddFusion(const InlinedHashSet& compatible_execution_providers = {}, + const bool preserve_attention_pattern = true) noexcept + : GraphTransformer("MatMulAddFusion", compatible_execution_providers), + preserve_attention_pattern_(preserve_attention_pattern) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + private: + bool preserve_attention_pattern_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 4b9259c080da3..7ceb61b4aabc5 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -276,14 +276,6 @@ Status ProcessNode( kMSDomain); matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType()); -#ifdef USE_ROCM - // forward the __backwardpass, if present - auto& attrs = node.GetAttributes(); - if (attrs.count("__backwardpass")) { - matmul_scale_node.AddAttribute("__backwardpass", static_cast(attrs.at("__backwardpass").i())); - } -#endif - { InlinedVector> nodes_to_remove{node}; diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index 9edd1ca4230f6..ad678d5384c66 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -412,13 +412,6 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ matmul_node.AddAttribute("alpha", alpha); // Assign provider to this new node. Provider should be same as the provider for old node. 
matmul_node.SetExecutionProviderType(node.GetExecutionProviderType()); -#ifdef USE_ROCM - // forward the __backwardpass, if present - auto& attrs = node.GetAttributes(); - if (attrs.count("__backwardpass")) { - matmul_node.AddAttribute("__backwardpass", static_cast(attrs.at("__backwardpass").i())); - } -#endif graph_utils::FinalizeNodeFusion(graph, matmul_node, node); diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 888ff1d0aa91e..6754e2471f52c 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -107,4 +107,13 @@ void Telemetry::LogDriverInfoEvent(const std::string_view device_class, ORT_UNUSED_PARAMETER(driver_versions); } +void Telemetry::LogAutoEpSelection(uint32_t session_id, const std::string& selection_policy, + const std::vector& requested_execution_provider_ids, + const std::vector& available_execution_provider_ids) const { + ORT_UNUSED_PARAMETER(session_id); + ORT_UNUSED_PARAMETER(selection_policy); + ORT_UNUSED_PARAMETER(requested_execution_provider_ids); + ORT_UNUSED_PARAMETER(available_execution_provider_ids); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index 99199c34f0464..0103588f0e0d7 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -78,6 +78,10 @@ class Telemetry { const std::wstring_view& driver_names, const std::wstring_view& driver_versions) const; + virtual void LogAutoEpSelection(uint32_t session_id, const std::string& selection_policy, + const std::vector& requested_execution_provider_ids, + const std::vector& available_execution_provider_ids) const; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry); }; diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index fed61854860f0..ad26280a90ecb 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -253,7 +253,8 @@ Memory_LeakCheck::~Memory_LeakCheck() { string.find("testing::internal::Mutex::ThreadSafeLazyInit") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetThreadLocalsMapLocked") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetValueOnCurrentThread") == std::string::npos && - string.find("PyInit_onnxruntime_pybind11_state") == std::string::npos) { + string.find("PyInit_onnxruntime_pybind11_state") == std::string::npos && + string.find("google::protobuf::internal::InitProtobufDefaultsSlow") == std::string::npos) { if (leaked_bytes == 0) DebugPrint("\n-----Starting Heap Trace-----\n\n"); diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index fdd4fa5b815d6..3908af40f962b 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -336,6 +336,7 @@ std::unordered_map GetDeviceInfoD3D12() { info.device_id = desc.DeviceId; info.description = std::wstring(desc.Description); + info.metadata[L"LUID"] = std::to_wstring(key); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; } @@ -436,8 +437,6 @@ std::unordered_map GetDeviceInfoDxcore() { continue; } - DeviceInfo& info = device_info[key]; - // Get hardware identifying information DXCoreHardwareIDParts 
idParts = {}; if (!adapter->IsPropertySupported(DXCoreAdapterProperty::HardwareIDParts) || @@ -445,8 +444,10 @@ std::unordered_map GetDeviceInfoDxcore() { continue; // also need valid ids } + DeviceInfo& info = device_info[key]; info.vendor_id = idParts.vendorID; info.device_id = idParts.deviceID; + info.metadata[L"LUID"] = std::to_wstring(key); // Is this a GPU or NPU if (adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)) { diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 39cd805b96d6e..47c9d0d75df16 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -433,4 +433,49 @@ void WindowsTelemetry::LogDriverInfoEvent(const std::string_view device_class, c TraceLoggingWideString(driver_versions.data(), "driverVersions")); } +void WindowsTelemetry::LogAutoEpSelection(uint32_t session_id, const std::string& selection_policy, + const std::vector& requested_execution_provider_ids, + const std::vector& available_execution_provider_ids) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + // Build requested execution provider string + std::string requested_execution_provider_string; + bool first = true; + for (const auto& ep_id : requested_execution_provider_ids) { + if (first) { + first = false; + } else { + requested_execution_provider_string += ','; + } + requested_execution_provider_string += ep_id; + } + + // Build available execution provider string + std::string available_execution_provider_string; + first = true; + for (const auto& ep_id : available_execution_provider_ids) { + if (first) { + first = false; + } else { + available_execution_provider_string += ','; + } + available_execution_provider_string += ep_id; + } + + TraceLoggingWrite(telemetry_provider_handle, + "EpAutoSelection", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingString(selection_policy.c_str(), "selectionPolicy"), + TraceLoggingString(requested_execution_provider_string.c_str(), "requestedExecutionProviderIds"), + TraceLoggingString(available_execution_provider_string.c_str(), "availableExecutionProviderIds")); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 1b4cc7b5408a5..787c8ba2d5e7f 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -71,6 +71,10 @@ class WindowsTelemetry : public Telemetry { const std::wstring_view& driver_names, const std::wstring_view& driver_versions) const override; + void LogAutoEpSelection(uint32_t session_id, const std::string& selection_policy, + const std::vector& requested_execution_provider_ids, + const std::vector& available_execution_provider_ids) const override; + using EtwInternalCallback = std::function; diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index badbf1f914fd2..b3f62bd13a24d 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -81,7 
+81,7 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status NonMaxSuppressionBase__PrepareCompute(OpKernelContext* ctx, PrepareContext& pc) override { return NonMaxSuppressionBase::PrepareCompute(ctx, pc); } Status NonMaxSuppressionBase__GetThresholdsFromInputs(const PrepareContext& pc, int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold) override { return NonMaxSuppressionBase::GetThresholdsFromInputs(pc, max_output_boxes_per_class, iou_threshold, score_threshold); } -#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) // From cpu/tensor/size.h (direct) Status Size__Compute(const Size* p, OpKernelContext* context) override { return p->Size::Compute(context); } // From cpu/tensor/scatter_nd.h (direct) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 170aa3a2d8d0c..a672dff0f4c01 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -1,18 +1,18 @@ /** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /* Modifications Copyright (c) Microsoft. 
*/ @@ -25,11 +25,11 @@ namespace onnxruntime { namespace cuda { -__device__ __forceinline__ void atomic_add(float *address, float value) { - atomicAdd(address, value); +__device__ __forceinline__ void atomic_add(float* address, float value) { + atomicAdd(address, value); } -__device__ __forceinline__ void atomic_add(double *address, double value) { +__device__ __forceinline__ void atomic_add(double* address, double value) { #if __CUDA_ARCH__ < 600 unsigned long long* raw_address = reinterpret_cast(address); unsigned long long raw_old_value = 0ULL; @@ -40,7 +40,7 @@ __device__ __forceinline__ void atomic_add(double *address, double value) { do { *p_old_value = *address; *p_new_value = *address + value; - seen_old_value = atomicCAS(raw_address, raw_old_value, raw_new_value); + seen_old_value = atomicCAS(raw_address, raw_old_value, raw_new_value); } while (seen_old_value != raw_old_value); #else atomicAdd(address, value); @@ -50,7 +50,7 @@ __device__ __forceinline__ void atomic_add(double *address, double value) { // // ref: https://github.com/pytorch/pytorch/blob/master/aten/src/THC/THCAtomics.cuh // -__device__ __forceinline__ void atomic_add(half *address, half value) { +__device__ __forceinline__ void atomic_add(half* address, half value) { #if __CUDA_ARCH__ < 700 unsigned int* base_address = (unsigned int*)((char*)address - ((size_t)address & 2)); unsigned int old = *base_address; @@ -90,7 +90,7 @@ __device__ __forceinline__ void atomic_add(BFloat16* address, BFloat16 value) { // But since the signature is different, we can change it for specific Op kernel once we find it is slow. // TODO: need to add same logic for BF16. template -__device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const size_t numel, T value) { +__device__ __forceinline__ void AtomicAdd(T* start_addr, size_t index, const size_t numel, T value) { ORT_UNUSED_PARAMETER(numel); atomic_add(start_addr + index, value); } @@ -128,42 +128,42 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, template class AtomicCasType; -template<> +template <> class AtomicCasType { public: using type = unsigned short int; static const unsigned int mask = 0xffu; }; -template<> +template <> class AtomicCasType { public: using type = unsigned short int; static const unsigned int mask = 0xffffu; }; -template<> +template <> class AtomicCasType { public: using type = unsigned int; static const unsigned int mask = 0xffffffffu; }; -template<> +template <> class AtomicCasType { public: using type = unsigned long long int; static const unsigned int mask = 0xffffffffu; }; -template<> +template <> class AtomicCasType { public: using type = int; static const unsigned int mask = 0xffffffffu; }; -template<> +template <> class AtomicCasType { public: using type = unsigned long long int; @@ -184,168 +184,168 @@ class AtomicCasType { // return a + b; // } // This function becomes atomic_add for int8_t. -template +template __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { - // Assert to ensure the following bit-wise manipulation is correct. - static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, - "ValueType must be 1-byte, 2-byte or 4-byte large."); - // Number of bytes to the lower 4-byte aligned address. - // If the current address is b1010"10", then offset = b10 = 2, - // which means the current address is 2 bytes away from - // the lower 4-byte aligned address b1010"00". 
- size_t offset = (size_t)address & 3; - // Find an new 4-byte aligned address `address_as_ui` lower than - // or equal to `address`. Lower than `address` so that the actual - // int8_t byte is in the 4-byte word that we load. + // Assert to ensure the following bit-wise manipulation is correct. + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". + size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t* address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V // - // This address has the following properties: - // 1. It is 4-byte aligned. - // 2. It is lower than or equal to `address`. - // 3. De-referencing this address may return - // a uint32_t value that contains the same int8_t - // value indicated by `address`. + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... // - // E.g., - // address = b101010 - // offset = b101010 & b000011 = b10 = 2 - // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", - // which is (32-bit aligned). - uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); - uint32_t old = *address_as_ui; - // E.g., offset = 2. - // address_as_ui is an address 2 bytes lower than `address`. + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. + // V // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... 
+ old_byte = (old >> shift) & AtomicCasType::mask; + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; + // Journey of a 32-bit value (cont'd): + // + // old // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... - // ^ ^ ^ - // | | | - // | address <--- offset * 8 (bit)-----> address_as_ui - // | ^ - // | | - // ------------------------- *address_as_ui ----------------------- // - // This visualization shows - // 1. the 32-bit word at address_as_ui. - // 2. the gap between address_as_ui and address. - // 3. *address_as_ui contains the int8_t value at `address`. - uint32_t shift = offset * 8; - uint32_t old_byte; - uint32_t newval; - uint32_t assumed; - do { - assumed = old; - // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so - // we want to select the 3rd byte (byte 2 below) from the word. - // - // Journey of a 32-bit value: - // - // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... - // - // | - // | old >> offset * 8, where offset = 2. - // | Effectively, push lower two bytes - // | out of the word. - // V - // - // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... - // - // | apply bit-wise AND, - // | & 0xff (i.e., & b11111111), - // | so that we only keep - // | the byte of interest. - // | Otherwise, overflow may - // | happen when casting this - // | 32-bit value to int8_t. - // V - // - // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... - old_byte = (old >> shift) & AtomicCasType::mask; - // Compute new int8_t value and store it to newrawvalue. - // Journey of a 32-bit value (cont'd): - // - // newrawvalue - // ... new byte 2 ... - auto newrawvalue = func(val, reinterpret_cast(old_byte)); - // Put the new int8_t value back to 32-bit word. - // Also ensure that bits not occupied by the int8_t value are 0s. - // - // Journey of a 32-bit value (cont'd): - // - // reinterpret_cast(newrawvalue) - // random values | random values | random values | ... new byte 2 ... - // - // reinterpret_cast(newrawvalue) & AtomicCasType::mask - // 00000000 | 00000000 | 00000000 | ... new byte 2 ... - newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; - // Journey of a 32-bit value (cont'd): - // - // old - // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... - // - // 0x000000ff - // 00000000 | 00000000 | 00000000 | 11111111 - // - // 0x000000ff << shift - // 00000000 | 11111111 | 00000000 | 00000000 - // - // ~(0x000000ff << shift) - // 11111111 | 00000000 | 11111111 | 11111111 - // - // old & ~(0x000000ff << shift) - // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... - // - // newval << shift - // 00000000 | ... new byte 2 ... | 00000000 | 00000000 - // - // (old & ~(0x000000ff << shift)) | (newval << shift) - // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... 
- newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); - old = atomicCAS(address_as_ui, assumed, newval); - } while (assumed != old); + // 0x000000ff + // 00000000 | 00000000 | 00000000 | 11111111 + // + // 0x000000ff << shift + // 00000000 | 11111111 | 00000000 | 00000000 + // + // ~(0x000000ff << shift) + // 11111111 | 00000000 | 11111111 | 11111111 + // + // old & ~(0x000000ff << shift) + // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... + // + // newval << shift + // 00000000 | ... new byte 2 ... | 00000000 | 00000000 + // + // (old & ~(0x000000ff << shift)) | (newval << shift) + // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... + newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); } // It accumulates `val` into the `address` using the `func`. // This function is thread-safe (i.e., atomic). -template +template __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { ValueType observed = *address, assumed, new_value; using CasType = typename AtomicCasType::type; static_assert(sizeof(ValueType) == sizeof(CasType), - "ValueType and CasType must have the same size for calling atomicCAS."); + "ValueType and CasType must have the same size for calling atomicCAS."); auto address_as_cas_type = reinterpret_cast(address); do { - // Record the value used to compute new value. - assumed = observed; - - // Compute expected new value. - new_value = func(observed, val); - - // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. - // 4 - // 8 - auto observed_as_cas_type = *reinterpret_cast(&observed); - auto new_value_as_cas_type = *reinterpret_cast(&new_value); - - // Call atomicCAS as if the 2-byte type variables are all unsigned short int. - // 4 unsigned int (or int) - // 8 unsigned long long int - auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); - - // Cast the freshly observed value in memory back to the TwoByteType. - observed = *reinterpret_cast(&cas_observed_as_cas_type); - - // Two cases: - // 1. compare-and-swap success - // a. `address` holds `new_value` - // b. `observed` becomes the new value after the assignment. - // Thus, the following `observed != new_value` is false, - // and the loop terminates. - // 2. compare-and-swap fails - // a. `address` holds a value different from `observed`, thus, - // the `new_value` is stale. - // b. `observed` becomes the fresh value observed in `address`. - // Thus, the following (observed != new_value) is true, - // and the loop continues. In the next iteration, the - // `new_value` is computed again using the fresh `observed`. + // Record the value used to compute new value. + assumed = observed; + + // Compute expected new value. + new_value = func(observed, val); + + // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. + // 4 + // 8 + auto observed_as_cas_type = *reinterpret_cast(&observed); + auto new_value_as_cas_type = *reinterpret_cast(&new_value); + + // Call atomicCAS as if the 2-byte type variables are all unsigned short int. + // 4 unsigned int (or int) + // 8 unsigned long long int + auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); + + // Cast the freshly observed value in memory back to the TwoByteType. 
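// A minimal standalone sketch (hypothetical name, int8_t addition only) of the byte-in-word
// CAS technique walked through above: locate the enclosing 4-byte aligned word, splice the
// updated byte back into it, and retry with atomicCAS until no other thread has raced us.
__device__ void atomic_add_int8_sketch(int8_t* address, int8_t value) {
  size_t offset = reinterpret_cast<size_t>(address) & 3;                    // byte index inside the 32-bit word
  uint32_t* word = reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(address) - offset);
  uint32_t shift = static_cast<uint32_t>(offset) * 8;
  uint32_t old = *word;
  uint32_t assumed;
  do {
    assumed = old;
    int8_t current = static_cast<int8_t>((assumed >> shift) & 0xffu);       // extract the byte of interest
    uint32_t updated = static_cast<uint8_t>(current + value);               // compute, keep only 8 bits
    uint32_t desired = (assumed & ~(0xffu << shift)) | (updated << shift);  // splice it back into the word
    old = atomicCAS(word, assumed, desired);
  } while (assumed != old);                                                 // another thread won; retry
}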
+ observed = *reinterpret_cast(&cas_observed_as_cas_type); + + // Two cases: + // 1. compare-and-swap success + // a. `address` holds `new_value` + // b. `observed` becomes the new value after the assignment. + // Thus, the following `observed != new_value` is false, + // and the loop terminates. + // 2. compare-and-swap fails + // a. `address` holds a value different from `observed`, thus, + // the `new_value` is stale. + // b. `observed` becomes the fresh value observed in `address`. + // Thus, the following (observed != new_value) is true, + // and the loop continues. In the next iteration, the + // `new_value` is computed again using the fresh `observed`. } while (observed != assumed); } @@ -432,6 +432,5 @@ __device__ __forceinline__ void atomic_min(double* address, double value) { atomic_binary_func(address, value, MinFunc()); } - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 2d2551a156099..9123c0bd76ec7 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -344,11 +344,9 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); } template <> __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); } -#define ISNAN_HALF(v__) static_cast(*reinterpret_cast(&v__) & ~MLFloat16::kSignMask) \ - > MLFloat16::kPositiveInfinityBits +#define ISNAN_HALF(v__) static_cast(*reinterpret_cast(&v__) & ~MLFloat16::kSignMask) > MLFloat16::kPositiveInfinityBits -#define ISNAN_BFLOAT16(v__) static_cast(*reinterpret_cast(&v__) & ~BFloat16::kSignMask) \ - > BFloat16::kPositiveInfinityBits +#define ISNAN_BFLOAT16(v__) static_cast(*reinterpret_cast(&v__) & ~BFloat16::kSignMask) > BFloat16::kPositiveInfinityBits // CUDART_NAN_BF16 and CUDART_NAN_FP16 constants were only added in CUDA 12.2, // so define our own equivalent constants to support older versions. @@ -364,12 +362,12 @@ __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; } template <> __device__ __inline__ float _Min(float a, float b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : (a < b ? a : b); } template <> __device__ __inline__ double _Min(double a, double b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : (a < b ? a : b); } template <> @@ -395,12 +393,12 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template <> __device__ __inline__ float _Max(float a, float b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : (a > b ? a : b); } template <> __device__ __inline__ double _Max(double a, double b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : (a > b ? 
a : b); } template <> @@ -624,36 +622,34 @@ struct _IsNan { template <> struct _IsNan { __device__ __inline__ bool operator()(half a) const { - return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) - > MLFloat16::kPositiveInfinityBits; + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) > MLFloat16::kPositiveInfinityBits; } }; template <> struct _IsNan { __device__ __inline__ bool operator()(BFloat16 a) const { - return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) - > BFloat16::kPositiveInfinityBits; + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) > BFloat16::kPositiveInfinityBits; } }; #if !defined(DISABLE_FLOAT8_TYPES) -template<> +template <> struct _IsNan { __device__ __inline__ bool operator()(Float8E4M3FN a) const { return (*reinterpret_cast(&a) & 0x7f) == 0x7f; } }; -template<> +template <> struct _IsNan { __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { return *reinterpret_cast(&a) == 0x80; } }; -template<> +template <> struct _IsNan { __device__ __inline__ bool operator()(Float8E5M2 a) const { uint8_t c = *reinterpret_cast(&a); @@ -661,7 +657,7 @@ struct _IsNan { } }; -template<> +template <> struct _IsNan { __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { return *reinterpret_cast(&a) == 0x80; diff --git a/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh b/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh index 66113a1dffa11..c8ddbadb12fb2 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/unary_elementwise_impl.cuh @@ -19,7 +19,7 @@ __global__ void _UnaryElementWise( InT value[NumElementsPerThread]; CUDA_LONG id = start; - #pragma unroll +#pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { if (id < N) { value[i] = input_data[id]; @@ -28,7 +28,7 @@ __global__ void _UnaryElementWise( } id = start; - #pragma unroll +#pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { if (id < N) { output_data[id] = functor(value[i]); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index e42422c1ce2b5..3aaeeee1cbc20 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -323,6 +323,37 @@ DataLayout CUDAExecutionProvider::GetPreferredLayout() const { return this->IsNHWCPreferred() ? 
DataLayout::NHWC : DataLayout::NCHW; } +std::optional CUDAExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const { +#if defined(ENABLE_CUDA_NHWC_OPS) + if (target_data_layout != DataLayout::NHWC) { + return std::nullopt; + } + + // TODO(mtavenrath) generate list from registered kernels using nhwc domain + static const std::unordered_set cuda_nhwc_onnx_ops{ + "BatchNormalization", + "Conv", + "ConvTranspose", + "GlobalMaxPool", + "MaxPool", + "GlobalAveragePool", + "AveragePool", + "GridSample", + "DepthToSpace", + "SpaceToDepth", + "LRN", + }; + + return (node_domain == kOnnxDomain && cuda_nhwc_onnx_ops.find(node_op_type) != cuda_nhwc_onnx_ops.end()) || + (node_domain == kMSDomain && node_op_type == "GridSample"); + +#else // defined(ENABLE_CUDA_NHWC_OPS) + return std::nullopt; +#endif +} + CUDAExecutionProvider::~CUDAExecutionProvider() { // clean up thread local context caches { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index a75e81f1f0c6d..57fde8146d929 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -39,6 +39,10 @@ class CUDAExecutionProvider : public IExecutionProvider { DataLayout GetPreferredLayout() const override; + std::optional ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const override; + const void* GetExecutionHandle() const noexcept override { // The CUDA interface does not return anything interesting. return nullptr; diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index ab5d14cf22c18..28fc9cb8e530f 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -24,8 +24,8 @@ namespace onnxruntime::cuda { -// When adding new supported NHWC operations make sure to also integrate them into: ConvertNodeLayout -// in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +// When adding new supported NHWC operations make sure to also integrate them into +// CUDAExecutionProvider::ShouldConvertDataLayoutForOp() class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, float, BatchNormalization); class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, double, BatchNormalization); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index b160c14bdc359..b1df607e8ce99 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -319,12 +319,12 @@ struct CudaEpFactory : OrtEpFactory { ReleaseEp = ReleaseEpImpl; } - static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->ep_name.c_str(); } - static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor.c_str(); } @@ -334,7 +334,7 @@ struct CudaEpFactory : OrtEpFactory { size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* p_num_ep_devices) { + size_t* p_num_ep_devices) noexcept { 
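// Usage sketch (hypothetical caller and variable names, not code from this change) for the
// ShouldConvertDataLayoutForOp hook added above: the layout-transformation pass can now ask
// the EP, per node, whether to rewrite that node to NHWC. std::nullopt is assumed to mean
// "no explicit opinion", in which case the transformer keeps its existing default behavior.
std::optional<bool> convert =
    ep.ShouldConvertDataLayoutForOp(node.Domain(), node.OpType(), DataLayout::NHWC);
if (convert.has_value()) {
  // *convert == true  -> convert this node to the NHWC layout
  // *convert == false -> leave this node in NCHW
} else {
  // fall back to the transformer's existing per-EP defaults
}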
size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); @@ -357,11 +357,11 @@ struct CudaEpFactory : OrtEpFactory { _In_ size_t /*num_devices*/, _In_ const OrtSessionOptions* /*session_options*/, _In_ const OrtLogger* /*logger*/, - _Out_ OrtEp** /*ep*/) { + _Out_ OrtEp** /*ep*/) noexcept { return CreateStatus(ORT_INVALID_ARGUMENT, "CUDA EP factory does not support this method."); } - static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept { // no-op as we never create an EP here. } diff --git a/onnxruntime/core/providers/cuda/cuda_type_conversion.h b/onnxruntime/core/providers/cuda/cuda_type_conversion.h new file mode 100644 index 0000000000000..f118bc9c69bcc --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_type_conversion.h @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#if defined(ENABLE_FP8) && !defined(DISABLE_FLOAT8_TYPES) +#include +#endif +#include +#include +#include "core/framework/int4.h" +#include "core/framework/float8.h" +#include "core/framework/float16.h" + +namespace onnxruntime { +namespace cuda { + +// Type mapping for ORT Type to CUDA Type +template +struct OrtToCudaType { + using type = T; + + static type FromFloat(float f) { + static_assert(std::is_same_v || std::is_same_v); + return static_cast(f); + } +}; + +template <> +struct OrtToCudaType { + using type = int8_t; +}; + +template <> +struct OrtToCudaType { + using type = uint8_t; +}; + +template <> +struct OrtToCudaType { + using type = __half; + static type FromFloat(float f) { + return type(f); + } +}; + +template <> +struct OrtToCudaType { + using type = __nv_bfloat16; + static type FromFloat(float f) { + return type(f); + } +}; + +#if defined(ENABLE_FP8) && !defined(DISABLE_FLOAT8_TYPES) +template <> +struct OrtToCudaType { + using type = __nv_fp8_e4m3; + static type FromFloat(float f) { + return type(f); + } +}; + +template <> +struct OrtToCudaType { + using type = __nv_fp8_e4m3; + static type FromFloat(float f) { + return type(f); + } +}; + +template <> +struct OrtToCudaType { + using type = __nv_fp8_e5m2; + static type FromFloat(float f) { + return type(f); + } +}; + +template <> +struct OrtToCudaType { + using type = __nv_fp8_e5m2; + static type FromFloat(float f) { + return type(f); + } +}; +#endif + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu index 1cc407efe8670..96e9e35e85f1a 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu @@ -90,17 +90,17 @@ namespace cuda { SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16) #define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(x) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint8_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint16_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int8_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int16_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, 
double) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint8_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint16_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int8_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int16_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16) #define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZIL(x) \ diff --git a/onnxruntime/core/providers/cuda/math/clip_impl.cu b/onnxruntime/core/providers/cuda/math/clip_impl.cu index cf4d79eec0280..c6982de3efdc7 100644 --- a/onnxruntime/core/providers/cuda/math/clip_impl.cu +++ b/onnxruntime/core/providers/cuda/math/clip_impl.cu @@ -24,21 +24,21 @@ void ClipImpl(cudaStream_t stream, const T* input_data, T* output_data, const T* int blocksPerGrid = (int)(ceil(static_cast(count) / GridDim::maxThreadsPerBlock)); union ConstAliasUnion { - const T *t; - const CudaT *cudaT; - ConstAliasUnion(const T* _t) { t = _t;} + const T* t; + const CudaT* cudaT; + ConstAliasUnion(const T* _t) { t = _t; } }; union AliasUnion { - T *t; - CudaT *cudaT; - AliasUnion(T* _t) { t = _t;} + T* t; + CudaT* cudaT; + AliasUnion(T* _t) { t = _t; } }; _Clip<<>>(((union ConstAliasUnion)input_data).cudaT, ((union AliasUnion)output_data).cudaT, ((union ConstAliasUnion)min).cudaT, ((union ConstAliasUnion)max).cudaT, - *((union AliasUnion)&min_default).cudaT, - *((union AliasUnion)&max_default).cudaT, + *((union AliasUnion) & min_default).cudaT, + *((union AliasUnion) & max_default).cudaT, count); } diff --git a/onnxruntime/core/providers/cuda/math/cumsum_impl.cu b/onnxruntime/core/providers/cuda/math/cumsum_impl.cu index 8a657dd9dcdfa..ad530a4a6dfd8 100644 --- a/onnxruntime/core/providers/cuda/math/cumsum_impl.cu +++ b/onnxruntime/core/providers/cuda/math/cumsum_impl.cu @@ -33,7 +33,7 @@ __global__ void _CumSumKernel( if (!reverse && !exclusive) { start = 0; end = axis_dim; - + } else if (reverse && !exclusive) { start = axis_dim; end = input_dim_along_axis - 1; @@ -42,24 +42,23 @@ __global__ void _CumSumKernel( start = 0; end = axis_dim - 1; - } else { // reverse && exclusive + } else { // reverse && exclusive start = axis_dim + 1; end = input_dim_along_axis - 1; - } // count the number of elements to accumulate the sum int count = end - start + 1; if (count <= 0) { output_data[indices_index] = 0; - return; + return; } // adjust start index based on the above identified start dim value along the axis of interest int data_index = static_cast(indices_index) + (start - axis_dim) * input_stride_along_axis; T sum = 0; - // keep accumulating values from the start index for 'count' times and skip appropriately + // keep accumulating values from the start index for 'count' times and skip appropriately while (count != 0) { sum += input_data[data_index]; data_index += input_stride_along_axis; @@ -83,12 +82,12 @@ void CumSumImpl( int blocksPerGrid = static_cast((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock); _CumSumKernel<<>>(input_data, - input_dim_along_axis, - input_stride_along_axis, - output_data, - output_size, - exclusive, - reverse); + input_dim_along_axis, + input_stride_along_axis, + output_data, + output_size, + exclusive, + reverse); } } @@ -164,4 +163,3 @@ template void CumSumImpl( } 
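// Standalone sketch (hypothetical names) of the alias-union pattern used in the ClipImpl hunk
// above: a tiny union lets a pointer to an ORT scalar type (here MLFloat16) be handed to a
// kernel expecting the matching CUDA type (__half) without a reinterpret_cast at every argument.
union ConstHalfAlias {
  const MLFloat16* t;
  const __half* cudaT;
  explicit ConstHalfAlias(const MLFloat16* p) { t = p; }
};
// e.g. pass ConstHalfAlias(input_data).cudaT to a kernel that takes const __half*.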
// namespace cuda } // namespace onnxruntime - diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.cu b/onnxruntime/core/providers/cuda/math/matmul_integer.cu index 757e39dc44e5c..9b2b10e931b5d 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.cu +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.cu @@ -29,9 +29,9 @@ __global__ void ReduceRowSumOnMatrixAKernel(const int8_t* matrix, int32_t* row_s Status ReduceRowSumOnMatrixA(cudaStream_t stream, const int8_t* matrix, int32_t* row_sum, const int8_t offset, const MatMulComputeHelper& helper) { for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) { ReduceRowSumOnMatrixAKernel(GridDim::maxThreadsPerBlock)><<(helper.M()), GridDim::maxThreadsPerBlock, 0, stream>>>(matrix + helper.LeftOffsets()[batch], - row_sum + batch * helper.M(), - offset, - static_cast(helper.K())); + row_sum + batch * helper.M(), + offset, + static_cast(helper.K())); } return CUDA_CALL(cudaGetLastError()); @@ -57,10 +57,10 @@ __global__ void ReduceColSumOnMatrixBKernel(const int8_t* matrix, int32_t* col_s Status ReduceColSumOnMatrixB(cudaStream_t stream, const int8_t* matrix, int32_t* col_sum, const int8_t offset, const MatMulComputeHelper& helper) { for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) { ReduceColSumOnMatrixBKernel(GridDim::maxThreadsPerBlock)><<(helper.N()), GridDim::maxThreadsPerBlock, 0, stream>>>(matrix + helper.RightOffsets()[batch], - col_sum + batch * helper.N(), - offset, - static_cast(helper.K()), - static_cast(helper.N())); + col_sum + batch * helper.N(), + offset, + static_cast(helper.K()), + static_cast(helper.N())); } return CUDA_CALL(cudaGetLastError()); diff --git a/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh b/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh index 8bb87035cdc6d..63e230085d05b 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh @@ -115,7 +115,7 @@ __device__ __forceinline__ AccumT blockReduce(AccumT* smem, AccumT val, AccumT blockVal = defaultVal; if (threadIdx.x == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < GPU_WARP_SIZE; ++i) { blockVal = r(blockVal, smem[i]); } @@ -158,7 +158,7 @@ __device__ __forceinline__ AccumT ilpReduce(int shift, for (; offset * ILP < (size - last); offset += blockDim.x) { *value = reinterpret_cast(data)[offset]; - #pragma unroll +#pragma unroll for (int j = 0; j < ILP; ++j) { threadVal = r(threadVal, v[j]); } @@ -213,7 +213,7 @@ __device__ __forceinline__ void WriteFpropResultsVectorized(int size, for (; offset * ILP < (size - last); offset += blockDim.x) { *in_value = reinterpret_cast(input)[offset]; - #pragma unroll +#pragma unroll for (int j = 0; j < ILP; ++j) { out_v[j] = epilogue(in_v[j]); } @@ -244,11 +244,11 @@ __device__ __forceinline__ void WriteFpropResults(int classes, for (; offset < classes - last; offset += blockDim.x * ILP) { scalar_t tmp[ILP]; - #pragma unroll +#pragma unroll for (int j = 0; j < ILP; ++j) { tmp[j] = input[offset + j * blockDim.x]; } - #pragma unroll +#pragma unroll for (int j = 0; j < ILP; ++j) { output[offset + j * blockDim.x] = epilogue(tmp[j]); } diff --git a/onnxruntime/core/providers/cuda/math/softmax_impl.cu b/onnxruntime/core/providers/cuda/math/softmax_impl.cu index a7da78fb4e146..04e66e9e1529e 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_impl.cu +++ b/onnxruntime/core/providers/cuda/math/softmax_impl.cu @@ -44,12 +44,12 @@ Status 
dispatch_warpwise_softmax_forward(Stream* ort_stream, output_t* dst, cons // there are 2 options to save one row of the input matrix: register or shared memory // when the number of elements is small, we use register; otherwise, we use shared memory; int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - if (log2_elements <= 10){ + if (log2_elements <= 10) { // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. // use 128 threads per block to maximimize gpu utilization threads_per_block = 128; shared_memory_size = 0; - } else{ + } else { // setting the number of threads per block to 32 will make index offset calculations easier, // under this setting, the cuda block number will be equal to batch size. threads_per_block = 32; @@ -63,47 +63,46 @@ Status dispatch_warpwise_softmax_forward(Stream* ort_stream, output_t* dst, cons dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { +#define LAUNCH_KERNEL(kernel_name, log2_elements_value) \ + kernel_name \ + <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); -#define LAUNCH_KERNEL(kernel_name, log2_elements_value) \ - kernel_name \ - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - -#define CASE_LOG2_ELEMENTS(log2_elements_value) \ - case log2_elements_value: { \ - if constexpr (log2_elements_value <= 10) { \ - LAUNCH_KERNEL(softmax_warp_forward, log2_elements_value) \ - } else { \ - LAUNCH_KERNEL(softmax_warp_forward_resource_efficient, log2_elements_value) \ - } \ +#define CASE_LOG2_ELEMENTS(log2_elements_value) \ + case log2_elements_value: { \ + if constexpr (log2_elements_value <= 10) { \ + LAUNCH_KERNEL(softmax_warp_forward, log2_elements_value) \ + } else { \ + LAUNCH_KERNEL(softmax_warp_forward_resource_efficient, log2_elements_value) \ + } \ } break - CASE_LOG2_ELEMENTS(0); - CASE_LOG2_ELEMENTS(1); - CASE_LOG2_ELEMENTS(2); - CASE_LOG2_ELEMENTS(3); - CASE_LOG2_ELEMENTS(4); - CASE_LOG2_ELEMENTS(5); - CASE_LOG2_ELEMENTS(6); - CASE_LOG2_ELEMENTS(7); - CASE_LOG2_ELEMENTS(8); - CASE_LOG2_ELEMENTS(9); - CASE_LOG2_ELEMENTS(10); - CASE_LOG2_ELEMENTS(11); // start to use softmax_warp_forward_resource_efficient instead of softmax_warp_forward for better performance + CASE_LOG2_ELEMENTS(0); + CASE_LOG2_ELEMENTS(1); + CASE_LOG2_ELEMENTS(2); + CASE_LOG2_ELEMENTS(3); + CASE_LOG2_ELEMENTS(4); + CASE_LOG2_ELEMENTS(5); + CASE_LOG2_ELEMENTS(6); + CASE_LOG2_ELEMENTS(7); + CASE_LOG2_ELEMENTS(8); + CASE_LOG2_ELEMENTS(9); + CASE_LOG2_ELEMENTS(10); + CASE_LOG2_ELEMENTS(11); // start to use softmax_warp_forward_resource_efficient instead of softmax_warp_forward for better performance #undef LAUNCH_KERNEL #undef CASE_LOG2_ELEMENTS - } // switch - } // else + } // switch + } // else return CUDA_CALL(cudaGetLastError()); } #define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ - template Status dispatch_warpwise_softmax_forward(Stream* ort_stream, \ + template Status dispatch_warpwise_softmax_forward(Stream * ort_stream, \ output_t * dst, \ const input_t* src, \ int softmax_elements, \ int softmax_elements_stride, \ int batch_count); \ - template Status dispatch_warpwise_softmax_forward(Stream* ort_stream, \ + template Status dispatch_warpwise_softmax_forward(Stream * ort_stream, \ output_t * dst, \ const input_t* src, \ int softmax_elements, \ @@ -137,10 +136,10 @@ Status dispatch_blockwise_softmax_forward(Stream* ort_stream, output_t* output, #define 
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ template Status dispatch_blockwise_softmax_forward( \ - Stream* ort_stream, output_t * output, const input_t* src, int softmax_elements, \ + Stream * ort_stream, output_t * output, const input_t* src, int softmax_elements, \ int input_stride, int output_stride, int batch_count); \ template Status dispatch_blockwise_softmax_forward( \ - Stream* ort_stream, output_t * output, const input_t* src, int softmax_elements, \ + Stream * ort_stream, output_t * output, const input_t* src, int softmax_elements, \ int input_stride, int output_stride, int batch_count); SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float) diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cuh b/onnxruntime/core/providers/cuda/math/topk_impl.cuh index 112566e54bbba..0cbc848be971c 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cuh @@ -28,8 +28,8 @@ struct KV { #define BT GridDim::maxThreadsPerBlock #define ALIGN(N) static_cast(pow(2, ceil(log2(static_cast(N))))) -#define FROM(idx) (left_dim + (idx)*mid_dim + right_dim) -#define TO(idx) (left_dim * K / dimension + (idx)*mid_dim + right_dim) +#define FROM(idx) (left_dim + (idx) * mid_dim + right_dim) +#define TO(idx) (left_dim * K / dimension + (idx) * mid_dim + right_dim) #define TRIVIAL (1 == largest ? type_min : type_max) #define BIGGER(n, m) (n.key > m.key ? n : (n.key < m.key ? m : (n.val > m.val ? (1 == largest ? m : n) : (1 == largest ? n : m)))) #define SMALLER(n, m) (n.key < m.key ? n : (n.key > m.key ? m : (n.val < m.val ? (1 == largest ? m : n) : (1 == largest ? n : m)))) diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu index 9311f044f4ec5..49293ea7b3c3a 100644 --- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu +++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu @@ -46,7 +46,7 @@ __global__ void MaxPoolWithIndexKernel( if (id >= output_size) return; auto compute_offset = - [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t { + [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t { if constexpr (Layout == LAYOUT_NCHW) { return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index; } else if constexpr (Layout == LAYOUT_NHWC) { @@ -108,8 +108,8 @@ __global__ void MaxPoolWithIndexKernel( // layouts, does it make sense to do an index conversion as well? // Storing indices in NHWC layout isn't critical as they are supposed to be used by Unpooling operations // which currently assume that indices reference to Tensors in NHWC layout. - int64_t id_nchw = - (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index; + int64_t id_nchw = + (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index; int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth; p_indices[id_nchw] = (storage_order == 0) @@ -161,9 +161,9 @@ void MaxPoolWithIndex( int64_t stride_h = stride_shape[0]; int64_t stride_w = stride_shape.size() > 1 ? stride_shape[1] : 1; int64_t stride_d = stride_shape.size() > 2 ? 
stride_shape[2] : 1; - //pads in the format of [x1_begin, x2_begin...x1_end, x2_end,...], - //where xi_begin the number of pixels added at the beginning of axis i - //and xi_end, the number of pixels added at the end of axis i. + // pads in the format of [x1_begin, x2_begin...x1_end, x2_end,...], + // where xi_begin the number of pixels added at the beginning of axis i + // and xi_end, the number of pixels added at the end of axis i. int64_t pad_h = pads[0]; int64_t pad_w = pads.size() >= 4 ? pads[1] : 0; int64_t pad_d = pads.size() == 6 ? pads[2] : 0; diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu index 79d38319e2c20..b0ef2207ce7e1 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu @@ -23,9 +23,8 @@ limitations under the License. #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" - #include -//TODO:fix the warnings +// TODO:fix the warnings #ifdef _MSC_VER #pragma warning(disable : 4244) #endif @@ -266,11 +265,11 @@ Status NmsGpu(cudaStream_t stream, thread_block.y = kNmsBlockDim; thread_block.z = 1; NMSKernel<<>>(center_point_box, - d_sorted_boxes, - num_boxes, - iou_threshold, - bit_mask_len, - d_delete_mask); + d_sorted_boxes, + num_boxes, + iou_threshold, + bit_mask_len, + d_delete_mask); IAllocatorUniquePtr d_selected_boxes_ptr{allocator(num_boxes * sizeof(char))}; auto* d_selected_boxes = static_cast(d_selected_boxes_ptr.get()); @@ -351,7 +350,7 @@ Status NonMaxSuppressionImpl( static_cast(nullptr), // input indices static_cast(nullptr), // sorted indices num_boxes, // num items - 0, 8 * sizeof(float), // sort all bits + 0, 8 * sizeof(float), // sort all bits stream)); // allocate temporary memory diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu index 10053c630ab66..3f56d197d6bd3 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu @@ -1,18 +1,18 @@ /** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /* Modifications Copyright (c) Microsoft. 
*/ #include "roialign_impl.h" diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu index 9db3fb1251f08..51c80d272bb96 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu @@ -318,28 +318,28 @@ template Status reduce_sum( cudaStream_t stream, const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) { return detail::call_reduce_matrix_columns( - stream, input, output, 1, size, buffer, buffer_size); + stream, input, output, 1, size, buffer, buffer_size); } template Status reduce_square_sum( cudaStream_t stream, const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) { return detail::call_reduce_matrix_columns( - stream, input, output, 1, size, buffer, buffer_size); + stream, input, output, 1, size, buffer, buffer_size); } template Status reduce_l2_norm( cudaStream_t stream, const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) { return detail::call_reduce_matrix_columns( - stream, input, output, 1, size, buffer, buffer_size); + stream, input, output, 1, size, buffer, buffer_size); } template Status reduce_mean( cudaStream_t stream, const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) { return detail::call_reduce_matrix_columns( - stream, input, output, 1, size, buffer, buffer_size); + stream, input, output, 1, size, buffer, buffer_size); } #define INSTANTIATE_REDUCE_SUM(TIn, TOut) \ @@ -500,7 +500,7 @@ INSTANTIATE_REDUCE_MATRIX_ROWS(BFloat16); template Status reduce_matrix_columns(cudaStream_t stream, const TIn* input, TOut* output, int m, int n, void* buffer, size_t buffer_size) { return detail::call_reduce_matrix_columns( - stream, input, output, m, n, buffer, buffer_size); + stream, input, output, m, n, buffer, buffer_size); } #define INSTANTIATE_REDUCE_MATRIX_COLUMNS(T) \ diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index 94c8036be6cdf..42ac4eaeca8a4 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -33,7 +33,7 @@ __global__ void _ReverseBySequenceKernel(const int32_t max_seq_length, template void ReverseBySequence(cudaStream_t stream, const int32_t max_seq_length, - const int32_t *seq_lengths, + const int32_t* seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -80,7 +80,7 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, const T* data, T* reordered_data, const size_t N) { - // The cudnn Y output is organize like [Y1, YB1] [Y2, YB2] ... + // The cudnn Y output is organize like [Y1, YB1] [Y2, YB2] ... // need to reorganize it to [Y1, Y2, ...] [YB1, YB2, ...] 
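// Host-side reference sketch (hypothetical helper, not the CUDA kernel in this file) of the
// reorder just described: cuDNN emits one [Yt_forward | Yt_backward] block per time step,
// while ORT wants every forward step first, then every backward step.
template <typename T>
void ReorderBidirectionalReference(const T* data, T* reordered,
                                   int seq_length, int batch_size, int hidden_size) {
  const int block = batch_size * hidden_size;    // elements per (time step, direction) block
  for (int t = 0; t < seq_length; ++t) {
    for (int d = 0; d < 2; ++d) {                // d == 0: forward, d == 1: backward
      const T* src = data + (t * 2 + d) * block;
      T* dst = reordered + (d * seq_length + t) * block;
      for (int i = 0; i < block; ++i) dst[i] = src[i];
    }
  }
}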
int32_t seq_block_size = 2 * batch_size * hidden_size; fast_divmod div_seq_block(seq_block_size); @@ -123,7 +123,7 @@ __global__ void _MaskZeroSequences(const int32_t hidden_size, } } -template +template void MaskZeroSequences(cudaStream_t stream, const int32_t hidden_size, T* y_output_data, @@ -136,29 +136,29 @@ void MaskZeroSequences(cudaStream_t stream, hidden_size, y_output_data, y_h_output_data, y_c_output_data, zeor_seq_index_cache, (CUDA_LONG)N); } -#define SPECIALIZED_RNN_IMPL(T) \ - template void ReverseBySequence(cudaStream_t stream, \ - const int32_t max_seq_length, \ - const int32_t* seq_lengths, \ - const int32_t batch_size, \ - const int32_t hidden_size, \ - const T* data, \ - T* reversed_data, \ - const size_t N); \ - template void ReorderBidirectionalDataInSequence(cudaStream_t stream,\ - const int32_t seq_length, \ - const int32_t batch_size, \ - const int32_t hidden_size,\ - const T* data, \ - T* reordered_data, \ - const size_t N); \ -template void MaskZeroSequences(cudaStream_t stream, \ - const int32_t hidden_size, \ - T* y_output_data, \ - T* y_h_output_data, \ - T* y_c_output_data, \ - const int32_t* zeor_seq_index_cache, \ - const size_t N); +#define SPECIALIZED_RNN_IMPL(T) \ + template void ReverseBySequence(cudaStream_t stream, \ + const int32_t max_seq_length, \ + const int32_t* seq_lengths, \ + const int32_t batch_size, \ + const int32_t hidden_size, \ + const T* data, \ + T* reversed_data, \ + const size_t N); \ + template void ReorderBidirectionalDataInSequence(cudaStream_t stream, \ + const int32_t seq_length, \ + const int32_t batch_size, \ + const int32_t hidden_size, \ + const T* data, \ + T* reordered_data, \ + const size_t N); \ + template void MaskZeroSequences(cudaStream_t stream, \ + const int32_t hidden_size, \ + T* y_output_data, \ + T* y_h_output_data, \ + T* y_c_output_data, \ + const int32_t* zeor_seq_index_cache, \ + const size_t N); SPECIALIZED_RNN_IMPL(half) SPECIALIZED_RNN_IMPL(float) diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 240272923a3a7..de445e07f5f07 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -12,7 +12,7 @@ // NV_TODO: investigate cub support for half #pragma once - +#include #include "core/providers/cuda/cuda_common.h" // Generalize library calls to be use in template functions @@ -169,11 +169,32 @@ inline cublasStatus_t cublasGemmHelper( return cublasGemmEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, &h_b, C, CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT); } + +inline cublasStatus_t cublasGemmHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, + int n, int k, const nv_bfloat16* alpha, const nv_bfloat16* A, int lda, + const nv_bfloat16* B, int ldb, const nv_bfloat16* beta, nv_bfloat16* C, int ldc, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { + float h_a = __bfloat162float(*alpha); + float h_b = __bfloat162float(*beta); + + // accumulating in FP32 + return cublasGemmEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, &h_b, C, + CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT); +} + #else inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const onnxruntime::BFloat16*, const onnxruntime::BFloat16*, int, const onnxruntime::BFloat16*, int, const onnxruntime::BFloat16*, - 
onnxruntime::BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { + onnxruntime::BFloat16*, int, const cudaDeviceProp&, bool) { + return CUBLAS_STATUS_NOT_SUPPORTED; +} + +inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, + int, int, const nv_bfloat16*, const nv_bfloat16*, int, + const nv_bfloat16*, int, const nv_bfloat16*, + nv_bfloat16*, int, const cudaDeviceProp&, bool) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cu b/onnxruntime/core/providers/cuda/tensor/cast_op.cu index f2c2e6d7458f9..c98eabecdabab 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cu +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cu @@ -166,8 +166,7 @@ Status CudaCastStd(cudaStream_t stream, const InT* input, OutT* output, size_t n input, output, static_cast(num_of_element), - CastStd() - ); + CastStd()); return Status::OK(); } @@ -197,8 +196,7 @@ Status CudaCastSat(cudaStream_t stream, const InT* input, OutT* output, size_t n output, static_cast(num_of_element), CastSat(), - saturate - ); + saturate); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu index 0c04e027ca1b9..54191ead1abec 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu @@ -6,7 +6,7 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" -//TODO:fix the warnings +// TODO:fix the warnings #ifdef _MSC_VER #pragma warning(disable : 4244) #endif @@ -29,8 +29,8 @@ struct CastToInt32 { }; cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int32_t* condition_cumulative_sum, int length, size_t& temp_storage_bytes) { - auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32()); - return cub::DeviceScan::InclusiveSum( + auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32()); + return cub::DeviceScan::InclusiveSum( nullptr, temp_storage_bytes, input_iter, condition_cumulative_sum, length, stream); } diff --git a/onnxruntime/core/providers/cuda/tensor/expand_impl.cu b/onnxruntime/core/providers/cuda/tensor/expand_impl.cu index cadc5dccc643d..e442afc8b6701 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/expand_impl.cu @@ -120,7 +120,7 @@ Status Expand2D( const int input_view_stride1) { #define EXPAND2D_ON(TYPE) \ case sizeof(TYPE): \ - ExpandKernel2D<<>>( \ + ExpandKernel2D<<>>( \ N, reinterpret_cast(input_data), reinterpret_cast(output_data), \ fdm_output_stride0, input_view_stride0, input_view_stride1); \ break @@ -165,7 +165,7 @@ Status ExpandImpl( #define EXPAND_ON(TYPE) \ case sizeof(TYPE): \ ExpandKernel \ - <<>>( \ + <<>>( \ rank, N_output, reinterpret_cast(input_data), reinterpret_cast(output_data), \ output_strides, input_strides); \ break diff --git a/onnxruntime/core/providers/cuda/tensor/eye_like_impl.cu b/onnxruntime/core/providers/cuda/tensor/eye_like_impl.cu index a3e588a2882f5..43b2871307528 100644 --- a/onnxruntime/core/providers/cuda/tensor/eye_like_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/eye_like_impl.cu @@ -35,13 +35,13 @@ void EyeLikeImpl( _EyeLikeKernel<<>>(offset, stripe, output_data, N); } -#define SPECIALIZED_IMPL(T) \ - template void EyeLikeImpl( \ - cudaStream_t stream, \ - size_t offset, \ - size_t stripe, \ - T* 
output_data, \ - size_t diag_count); +#define SPECIALIZED_IMPL(T) \ + template void EyeLikeImpl( \ + cudaStream_t stream, \ + size_t offset, \ + size_t stripe, \ + T* output_data, \ + size_t diag_count); SPECIALIZED_IMPL(int32_t) SPECIALIZED_IMPL(int64_t) diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index ac380551f411f..81acb81be5025 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -269,22 +269,22 @@ template Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data, T* output_data, const GatherScatterElementsArgs& args) { if (args.operation == GatherScatterElementsArgs::Operation::NONE) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncAssignment()); + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAssignment()); } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncAdd()); + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAdd()); } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncMul()); + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMul()); } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncMax()); + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMax()); } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncMin()); + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMin()); } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator."); } } diff --git a/onnxruntime/core/providers/cuda/tensor/gather_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_impl.cu index 2fb91e7ce5dbd..e0b5672b2a9d8 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_impl.cu @@ -63,7 +63,6 @@ void GatherImpl( size_t element_size, void* output_data, const size_t N) { - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); switch (element_size) { diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu index b23da635bc83d..b5b4a84576bbe 100644 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu @@ -21,7 +21,6 @@ __device__ T GsDenormalize(T n, int64_t length, bool align_corners) { return x; } - template __device__ T GsReflect(T x, float x_min, float x_max) { float fx = static_cast(x); @@ -57,8 +56,8 @@ __device__ T PixelAtGrid(const T* 
input_data, int64_t bIdx, int64_t cIdx, int64_ auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t { return Layout == LAYOUT_NCHW - ? (bIdx * C * H * W + cIdx * H * W + y * W + x) - : (bIdx * H * W * C + y * W * C + x * C + cIdx); + ? (bIdx * C * H * W + cIdx * H * W + y * W + x) + : (bIdx * H * W * C + y * W * C + x * C + cIdx); }; if (padding_mode == 0) { // zeros @@ -112,121 +111,120 @@ __global__ void _GridSampleKernel( const int64_t W_in, const int64_t H_out, const int64_t W_out, - T* output_data) -{ - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); - // extract batch index, channel index, y index, x index for current thread - int BIdx, yIdx, xIdx, cIdx; - if constexpr (Layout == LAYOUT_NCHW) { - BIdx = idx / (C * H_out * W_out); - int tmpBCnt = BIdx * (C * H_out * W_out); + T* output_data) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); + // extract batch index, channel index, y index, x index for current thread + int BIdx, yIdx, xIdx, cIdx; + if constexpr (Layout == LAYOUT_NCHW) { + BIdx = idx / (C * H_out * W_out); + int tmpBCnt = BIdx * (C * H_out * W_out); - cIdx = (idx - tmpBCnt) / (H_out * W_out); - int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); + cIdx = (idx - tmpBCnt) / (H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); - yIdx = (idx - tmpCCnt) / W_out; - int tmpHCnt = tmpCCnt + yIdx * W_out; + yIdx = (idx - tmpCCnt) / W_out; + int tmpHCnt = tmpCCnt + yIdx * W_out; - xIdx = (idx - tmpHCnt); - } else { - static_assert(Layout == LAYOUT_NHWC, "Unsupported layout"); + xIdx = (idx - tmpHCnt); + } else { + static_assert(Layout == LAYOUT_NHWC, "Unsupported layout"); - BIdx = idx / (H_out * W_out * C); - int tmpBCnt = BIdx * (H_out * W_out * C); + BIdx = idx / (H_out * W_out * C); + int tmpBCnt = BIdx * (H_out * W_out * C); - yIdx = (idx - tmpBCnt) / (W_out * C); - int tmpHCnt = tmpBCnt + yIdx * (W_out * C); + yIdx = (idx - tmpBCnt) / (W_out * C); + int tmpHCnt = tmpBCnt + yIdx * (W_out * C); - xIdx = (idx - tmpHCnt) / C; - int tmpWCnt = tmpHCnt + xIdx * C; + xIdx = (idx - tmpHCnt) / C; + int tmpWCnt = tmpHCnt + xIdx * C; - cIdx = (idx - tmpWCnt); - } + cIdx = (idx - tmpWCnt); + } - int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; - T grid_X = grid_data[grid_idx * 2 + 0]; - T grid_Y = grid_data[grid_idx * 2 + 1]; - int outIdx = idx; + int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; + T grid_X = grid_data[grid_idx * 2 + 0]; + T grid_Y = grid_data[grid_idx * 2 + 1]; + int outIdx = idx; - T grid_x_imgSpace = GsDenormalize(grid_X, W_in, align_corners == 1); - T grid_y_imgSpace = GsDenormalize(grid_Y, H_in, align_corners == 1); - if (mode == 1) { //nearest - grid_x_imgSpace = nearbyint(grid_x_imgSpace); - grid_y_imgSpace = nearbyint(grid_y_imgSpace); - } - float x_min = -0.5f; - float x_max = W_in - 0.5f; - float y_min = -0.5f; - float y_max = H_in - 0.5f; + T grid_x_imgSpace = GsDenormalize(grid_X, W_in, align_corners == 1); + T grid_y_imgSpace = GsDenormalize(grid_Y, H_in, align_corners == 1); + if (mode == 1) { // nearest + grid_x_imgSpace = nearbyint(grid_x_imgSpace); + grid_y_imgSpace = nearbyint(grid_y_imgSpace); + } + float x_min = -0.5f; + float x_max = W_in - 0.5f; + float y_min = -0.5f; + float y_max = H_in - 0.5f; - if (align_corners) { - x_min = 0.0f; - x_max = W_in - 1.0; - y_min = 0.0f; - y_max = H_in - 1.0f; - } - float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b - if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || - grid_y_imgSpace < y_min || grid_y_imgSpace > 
y_max) { // out of bound - if (padding_mode == 1) { // border - // Clamping must not be done here, see #10607 - // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); - // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); - } else if (padding_mode == 2) { // reflection - grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); - grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); - } + if (align_corners) { + x_min = 0.0f; + x_max = W_in - 1.0; + y_min = 0.0f; + y_max = H_in - 1.0f; + } + float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b + if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || + grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound + if (padding_mode == 1) { // border + // Clamping must not be done here, see #10607 + // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); + // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); + } else if (padding_mode == 2) { // reflection + grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); + grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); } + } - if (mode == 0) { // bilinear - int x1 = floor(grid_x_imgSpace); - int y1 = floor(grid_y_imgSpace); - int x2 = x1 + 1; - int y2 = y1 + 1; - T w_lt = 0.0f; - T w_rt = 0.0f; - T w_lb = 0.0f; - T w_rb = 0.0f; + if (mode == 0) { // bilinear + int x1 = floor(grid_x_imgSpace); + int y1 = floor(grid_y_imgSpace); + int x2 = x1 + 1; + int y2 = y1 + 1; + T w_lt = 0.0f; + T w_rt = 0.0f; + T w_lb = 0.0f; + T w_rb = 0.0f; - T w_r = grid_x_imgSpace - x1; - T w_l = 1.0f - w_r; - T w_b = grid_y_imgSpace - y1; - T w_t = 1.0f - w_b; + T w_r = grid_x_imgSpace - x1; + T w_l = 1.0f - w_r; + T w_b = grid_y_imgSpace - y1; + T w_t = 1.0f - w_b; - w_lt = w_t * w_l; - w_rt = w_t * w_r; - w_lb = w_b * w_l; - w_rb = w_b * w_r; + w_lt = w_t * w_l; + w_rt = w_t * w_r; + w_lb = w_b * w_l; + w_rb = w_b * w_r; - T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); - T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); - T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); - T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); - T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; - output_data[outIdx] = interpoV; - return; - } - if (mode == 1) { // nearest - int x_n = grid_x_imgSpace; - int y_n = grid_y_imgSpace; - output_data[outIdx] = + T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); + T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); + T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); + T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); + T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; + output_data[outIdx] = interpoV; + return; + } + if (mode == 1) { // nearest + int x_n = grid_x_imgSpace; + int y_n = grid_y_imgSpace; + output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); - return; - } - if (mode == 2) { // bicubic - int64_t x0 = static_cast(std::floor(grid_x_imgSpace)) - 1; // top-left corner of the bbox - int64_t y0 = static_cast(std::floor(grid_y_imgSpace)) - 1; - T p[4][4] = {}; // [H][W] - for (int64_t h = 0; h < 4; h++) { - for (int64_t w = 0; w < 4; w++) { - p[h][w] = + return; + } + if 
(mode == 2) { // bicubic + int64_t x0 = static_cast(std::floor(grid_x_imgSpace)) - 1; // top-left corner of the bbox + int64_t y0 = static_cast(std::floor(grid_y_imgSpace)) - 1; + T p[4][4] = {}; // [H][W] + for (int64_t h = 0; h < 4; h++) { + for (int64_t w = 0; w < 4; w++) { + p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); - } } - T dx = grid_x_imgSpace - x0 - 1; - T dy = grid_y_imgSpace - y0 - 1; - output_data[outIdx] = GsBicubicInterpolate(p, dx, dy); } + T dx = grid_x_imgSpace - x0 - 1; + T dy = grid_y_imgSpace - y0 - 1; + output_data[outIdx] = GsBicubicInterpolate(p, dx, dy); + } } template @@ -244,9 +242,9 @@ void GridSampleImpl( using Ch = Channels; int blocksPerGrid = static_cast( - ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); + ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); _GridSampleKernel<<>>( - input_data, grid_data, mode, padding_mode, align_corners, + input_data, grid_data, mode, padding_mode, align_corners, dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], H_out, W_out, output_data); } diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu index ce5a1ebf3faa5..9cf227ee2414e 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu @@ -12,7 +12,7 @@ namespace cuda { static const int NONZERO_THREADS_PER_BLOCK = GridDim::maxThreadsPerBlock; -//TODO:check overflow +// TODO:check overflow int NonZeroCalcBlockCount(int64_t x_size) { return static_cast(CeilDiv(x_size, NONZERO_THREADS_PER_BLOCK)); } @@ -70,7 +70,6 @@ __global__ void NonZeroOutputPositionsKernel( } } - constexpr int MAX_DIMS = 16; template @@ -92,7 +91,7 @@ __global__ void UnRolledNonZeroOutputPositionsKernel( if (index < x_size && bool(x[index])) { int remain = (int)index, dim = 0; int rp = result_position; - #pragma unroll +#pragma unroll for (int axis = 0; axis < MAX_DIMS; ++axis) { if (axis == x_rank) { break; @@ -119,12 +118,12 @@ cudaError_t NonZeroOutputPositions( int num_blocks = NonZeroCalcBlockCount(x_size); if (x_rank > MAX_DIMS) { NonZeroOutputPositionsKernel<<>>( - x, x_size, x_rank, x_strides, - prefix_counts, nonzero_elements, results); + x, x_size, x_rank, x_strides, + prefix_counts, nonzero_elements, results); } else { UnRolledNonZeroOutputPositionsKernel<<>>( - x, x_size, x_rank, x_strides, - prefix_counts, nonzero_elements, results); + x, x_size, x_rank, x_strides, + prefix_counts, nonzero_elements, results); } return cudaSuccess; } diff --git a/onnxruntime/core/providers/cuda/tensor/onehot.cu b/onnxruntime/core/providers/cuda/tensor/onehot.cu index 1fb8dbe8b8805..92f86049f3da7 100644 --- a/onnxruntime/core/providers/cuda/tensor/onehot.cu +++ b/onnxruntime/core/providers/cuda/tensor/onehot.cu @@ -37,7 +37,7 @@ __global__ void _OneHotImpl( output_data[id] = (is_valid_range && adjusted_indice == in_type(depth_index)) ? 
on_value : off_value; } -template +template __global__ void _OneHotWithZeroOffValueImpl( const in_type* indices_data, const fast_divmod fdm_suffix, @@ -68,14 +68,14 @@ void OneHotImpl( int blocksPerGrid = (int)(ceil(static_cast(count) / GridDim::maxThreadsPerBlock)); CUDA_LONG N = static_cast(count); _OneHotImpl<<>>( - indices_data, - fdm_depth_suffix, - fdm_suffix, - depth_val, - on_value, - off_value, - output_data, - N); + indices_data, + fdm_depth_suffix, + fdm_suffix, + depth_val, + on_value, + off_value, + output_data, + N); } template @@ -90,47 +90,47 @@ void OneHotWithZeroOffValueImpl( int blocksPerGrid = (int)(ceil(static_cast(count) / GridDim::maxThreadsPerBlock)); CUDA_LONG N = static_cast(count); _OneHotWithZeroOffValueImpl<<>>( - indices_data, - fdm_suffix, - depth_val, - on_value, - output_data, - N); + indices_data, + fdm_suffix, + depth_val, + on_value, + output_data, + N); } #define SPECIALIZED_OneHotImpl(in_type, out_type) \ template void OneHotImpl( \ - cudaStream_t stream, \ - const in_type* indices_data, \ - const fast_divmod fdm_depth_suffix, \ - const fast_divmod fdm_suffix, \ - const int64_t depth_val, \ - const out_type on_value, \ - const out_type off_value, \ - out_type* output_data, \ - size_t count); + cudaStream_t stream, \ + const in_type* indices_data, \ + const fast_divmod fdm_depth_suffix, \ + const fast_divmod fdm_suffix, \ + const int64_t depth_val, \ + const out_type on_value, \ + const out_type off_value, \ + out_type* output_data, \ + size_t count); SPECIALIZED_OneHotImpl(int64_t, int64_t) -SPECIALIZED_OneHotImpl(int64_t, float) -SPECIALIZED_OneHotImpl(int32_t, float) -SPECIALIZED_OneHotImpl(int64_t, half) -SPECIALIZED_OneHotImpl(int32_t, half) + SPECIALIZED_OneHotImpl(int64_t, float) + SPECIALIZED_OneHotImpl(int32_t, float) + SPECIALIZED_OneHotImpl(int64_t, half) + SPECIALIZED_OneHotImpl(int32_t, half) #define SPECIALIZED_OneHotWithZeroOffValueImpl(in_type, out_type) \ template void OneHotWithZeroOffValueImpl( \ - cudaStream_t stream, \ - const in_type* indices_data, \ - const fast_divmod fdm_suffix, \ - const int64_t depth_val, \ - const out_type on_value, \ - out_type* output_data, \ - size_t count); - -SPECIALIZED_OneHotWithZeroOffValueImpl(int64_t, int64_t) -SPECIALIZED_OneHotWithZeroOffValueImpl(int64_t, float) -SPECIALIZED_OneHotWithZeroOffValueImpl(int32_t, float) -SPECIALIZED_OneHotWithZeroOffValueImpl(int64_t, half) -SPECIALIZED_OneHotWithZeroOffValueImpl(int32_t, half) + cudaStream_t stream, \ + const in_type* indices_data, \ + const fast_divmod fdm_suffix, \ + const int64_t depth_val, \ + const out_type on_value, \ + out_type* output_data, \ + size_t count); + + SPECIALIZED_OneHotWithZeroOffValueImpl(int64_t, int64_t) + SPECIALIZED_OneHotWithZeroOffValueImpl(int64_t, float) + SPECIALIZED_OneHotWithZeroOffValueImpl(int32_t, float) + SPECIALIZED_OneHotWithZeroOffValueImpl(int64_t, half) + SPECIALIZED_OneHotWithZeroOffValueImpl(int32_t, half) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu index 19b148d9193c9..f1012bf7296fb 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu @@ -419,13 +419,13 @@ Status CudaQuantizeLinearStdInt4(cudaStream_t stream, const InT* input, OutT* ou int blocksPerGrid = static_cast(CeilDiv(num_of_element, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); QuantizeLinearKernelStdInt4 - <<>>( - 
input, - output, - scale, - zero_point, - static_cast(num_of_element), - RoundStdInt4()); + <<>>( + input, + output, + scale, + zero_point, + static_cast(num_of_element), + RoundStdInt4()); return Status::OK(); } @@ -611,15 +611,15 @@ __global__ void DequantizeLinearKernelAxisStdInt4(const InT* input, OutT* output #pragma unroll for (; i + 1 < NumElementsPerThread && id + 1 < num_element; i += 2, id += step) { - scale_id0 = (id / n_same_scale) % n_scales; - scale_id1 = ((id + 1) / n_same_scale) % n_scales; - - v0 = ExtractInt4FromByte(input[id >> 1], 0); - v1 = ExtractInt4FromByte(input[id >> 1], 1); - zp0 = zero_point_ptr == nullptr ? 0 : ExtractInt4FromByte(zero_point_ptr[scale_id0 >> 1], scale_id0 & 1); - zp1 = zero_point_ptr == nullptr ? 0 : ExtractInt4FromByte(zero_point_ptr[scale_id1 >> 1], scale_id1 & 1); - output[id] = static_cast(v0 - zp0) * scale_ptr[scale_id0]; - output[id + 1] = static_cast(v1 - zp1) * scale_ptr[scale_id1]; + scale_id0 = (id / n_same_scale) % n_scales; + scale_id1 = ((id + 1) / n_same_scale) % n_scales; + + v0 = ExtractInt4FromByte(input[id >> 1], 0); + v1 = ExtractInt4FromByte(input[id >> 1], 1); + zp0 = zero_point_ptr == nullptr ? 0 : ExtractInt4FromByte(zero_point_ptr[scale_id0 >> 1], scale_id0 & 1); + zp1 = zero_point_ptr == nullptr ? 0 : ExtractInt4FromByte(zero_point_ptr[scale_id1 >> 1], scale_id1 & 1); + output[id] = static_cast(v0 - zp0) * scale_ptr[scale_id0]; + output[id + 1] = static_cast(v1 - zp1) * scale_ptr[scale_id1]; } if (i < NumElementsPerThread && id < num_element) { diff --git a/onnxruntime/core/providers/cuda/tensor/reverse_sequence_impl.cu b/onnxruntime/core/providers/cuda/tensor/reverse_sequence_impl.cu index 4d37b6a206ece..f9b60b54f5226 100644 --- a/onnxruntime/core/providers/cuda/tensor/reverse_sequence_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/reverse_sequence_impl.cu @@ -80,7 +80,7 @@ cudaError_t ReverseSequenceCudaImpl( #define InstantiateReverseSequenceImpl(T) \ template cudaError_t ReverseSequenceCudaImpl( \ - cudaStream_t stream, \ + cudaStream_t stream, \ const T* x_data, \ const int64_t* seq_len_data, \ T* y_data, \ diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu index e2f42e4d5855c..6c2cdfe029a08 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu @@ -196,7 +196,7 @@ Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t void* output_data2, const gsl::span& input_shape) { CUDA_LONG outer_size = 1; for (size_t i = 0; i < input_shape.size() - 1; ++i) { - outer_size *= static_cast(input_shape[i]); + outer_size *= static_cast(input_shape[i]); } CUDA_LONG inner_size_in_byte = static_cast(input_shape[input_shape.size() - 1] * element_size); @@ -234,16 +234,16 @@ Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t } switch (VEC_SIZE) { -#define CASE_ELEMENT_TYPE(type) \ - _Split3InnerKernel<<>>( \ - size0_in_byte, \ - size1_in_byte, \ - size2_in_byte, \ - input_data, \ - output_data0, \ - output_data1, \ - output_data2, \ - inner_size_in_byte) +#define CASE_ELEMENT_TYPE(type) \ + _Split3InnerKernel<<>>( \ + size0_in_byte, \ + size1_in_byte, \ + size2_in_byte, \ + input_data, \ + output_data0, \ + output_data1, \ + output_data2, \ + inner_size_in_byte) case 16: CASE_ELEMENT_TYPE(int4); break; diff --git a/onnxruntime/core/providers/cuda/tensor/where_impl.cu b/onnxruntime/core/providers/cuda/tensor/where_impl.cu index 
d7909454e922c..d55dbd352cd9c 100644 --- a/onnxruntime/core/providers/cuda/tensor/where_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/where_impl.cu @@ -115,19 +115,19 @@ __global__ void _TenaryElementWiseSimple( } } -#define HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \ - case Y_INDEX_TYPE: { \ - _TenaryElementWiseSimple \ - <<>>(cond_data, \ - x_data, \ - y_data, \ - output_data, \ - N); \ +#define HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \ + case Y_INDEX_TYPE: { \ + _TenaryElementWiseSimple \ + <<>>(cond_data, \ + x_data, \ + y_data, \ + output_data, \ + N); \ } break #define HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \ @@ -146,24 +146,24 @@ __global__ void _TenaryElementWiseSimple( } \ } break -#define HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \ - case Y_INDEX_TYPE: { \ - _TenaryElementWise \ +#define HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \ + case Y_INDEX_TYPE: { \ + _TenaryElementWise \ <<>>(output_rank_or_simple_broadcast, \ - cond_padded_strides, \ - cond_data, \ - x_padded_strides, \ - x_data, \ - y_padded_strides, \ - y_data, \ - fdm_output_strides, \ - output_data, \ - N); \ + cond_padded_strides, \ + cond_data, \ + x_padded_strides, \ + x_data, \ + y_padded_strides, \ + y_data, \ + fdm_output_strides, \ + output_data, \ + N); \ } break #define HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \ diff --git a/onnxruntime/core/providers/cuda/triton_kernel.cu b/onnxruntime/core/providers/cuda/triton_kernel.cu index ee93be3e67c27..d8bc4cd9abe60 100644 --- a/onnxruntime/core/providers/cuda/triton_kernel.cu +++ b/onnxruntime/core/providers/cuda/triton_kernel.cu @@ -130,8 +130,6 @@ void LoadOrtTritonKernel() { std::call_once(load_ort_triton_kernel_flag, TryToLoadKernel); } - - #ifdef USE_TRITON_KERNEL Status LaunchTritonKernel(cudaStream_t stream, size_t idx, int grid0, int grid1, int grid2, void* args, size_t args_size) { @@ -195,7 +193,6 @@ Status LaunchTritonKernel(cudaStream_t /*stream*/, size_t /*idx*/, int /*grid0*/ } #endif - const TritonKernelMetaData* GetOrtTritonKernelMetadata(size_t idx) { if (idx >= ort_triton_kernel_metadata.size()) { return nullptr; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 2be286440bcf4..102ea5378640c 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -849,6 +849,23 @@ std::unique_ptr JsExecutionProvider::GetExtern return std::make_unique(); } +std::optional JsExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const { + if (target_data_layout != DataLayout::NHWC) { + return std::nullopt; + } + + // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed + if (node_domain == kOnnxDomain && node_op_type == "Resize") { + // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain + // with the original input layout. 
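// Editorial note (not part of this patch): ShouldConvertDataLayoutForOp appears to use its
// optional<bool> return value as a three-way switch: std::nullopt keeps the default
// layout-conversion decision, true forces conversion to the target layout, and false keeps
// the node in its original domain and layout, which is what this Resize special case does.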
+ return false; + } + + return std::nullopt; +} + JsExecutionProvider::~JsExecutionProvider() { } diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index c87303209c689..7847285782e1e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -54,6 +54,10 @@ class JsExecutionProvider : public IExecutionProvider { DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } + std::optional ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const override; + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } // JSEP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to work, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 716c2c39cd837..1b5882062361c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -337,7 +337,7 @@ std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, info_.device_id); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA_PINNED); + return std::make_unique(device_id, onnxruntime::CUDA_PINNED); }, 0); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index d9d3507a4687b..f90bf24ef4975 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -80,7 +80,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi info.ep_context_file_path = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); // If embed mode is not specified, default to 1 if dump_ep_context_model is true, otherwise 0 - const auto embed_mode = std::stoi(session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "-1")); + auto embed_mode = std::stoi(session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "-1")); if (embed_mode == -1) { if (info.dump_ep_context_model) embed_mode = 1; diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 3d36fe5e8ff31..529f6ce824033 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -188,6 +188,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateMatMulOpBuilder("MatMul", *this); } + { + CreateMeanOpBuilder("Mean", *this); + } + { CreateLSTMOpBuilder("LSTM", *this); } diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index b9b3c34467855..4fa9ec6cc0fe1 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -126,6 +126,8 @@ 
void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateMeanOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + void CreateUDOBuilder(const std::string& op_type, const std::string& op_package, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc index b367f58faf139..fa5e95727e651 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc @@ -12,7 +12,7 @@ namespace qnn { /** * An ONNX MatMul can be translated to either a QNN MatMul or a QNN FullyConnected. - * ONNX's MatMul suports inputs of rank 1, but neither QNN's MatMul nor FullyConnected support two rank 1 inputs. + * ONNX's MatMul supports inputs of rank 1, but neither QNN's MatMul nor FullyConnected support two rank 1 inputs. * So, we need to add Reshape Ops if necessary. * In two cases, FullyConnected (input_1's shape is [n, k]) is used instead of MatMul without extra Transpose Op: * 1. input_1 is a rank 2 initializer. diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/mean_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/mean_op_builder.cc new file mode 100644 index 0000000000000..07e73350fde5f --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/mean_op_builder.cc @@ -0,0 +1,115 @@ +// Copyright (c) Qualcomm. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include + +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class MeanOpBuilder : public BaseOpBuilder { + public: + MeanOpBuilder() : BaseOpBuilder("MeanOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MeanOpBuilder); + + protected: + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +Status MeanOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, const logging::Logger& logger, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(logger); + ORT_UNUSED_PARAMETER(do_op_validation); + + const auto& inputs = node_unit.Inputs(); + const auto& output = node_unit.Outputs()[0]; + + if (inputs.size() < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Mean op requires at least two inputs."); + } + + // Combine Add Operations together + std::string sum_output = input_names[0]; + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info)); + + for (size_t i = 1; i < input_names.size(); ++i) { + // Get output shape + std::vector output_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape), "Failed to get output shape."); + std::vector unpackage_data(sizeof(float)); + + const std::string add_output = sum_output + "_ort_qnn_ep_add_" + std::to_string(i); + 
QnnTensorWrapper add_tensor(add_output, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type, + QnnQuantParamsWrapper(), std::move(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(add_tensor)), + "Failed to add Add tensor wrapper."); + const std::string add_op_name = "Mean_Add_" + std::to_string(i); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(add_op_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_ADD, + {sum_output, input_names[i]}, + {add_output}, + {}, + do_op_validation), + "Create Qnn Node for Add Op Failed"); + + sum_output = add_output; + } + + // Number of inputs to divide with + float divisor = static_cast(inputs.size()); + std::vector scalar_shape = {1}; + std::vector divisor_data(sizeof(float)); + memcpy(divisor_data.data(), &divisor, sizeof(float)); + + const std::string divisor_name = sum_output + "_ort_qnn_ep_divisor"; + + QnnTensorWrapper divisor_tensor(divisor_name, QNN_TENSOR_TYPE_STATIC, input_info.qnn_data_type, + QnnQuantParamsWrapper(), std::move(scalar_shape), std::move(divisor_data)); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(divisor_tensor)), "AddTensorWrapper Failed"); + + // Final step - Division + const std::string output_name = output.node_arg.Name(); + std::vector output_shape; + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, output_info)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape), "Failed to get output shape."); + Qnn_TensorType_t output_tensor_type = qnn_model_wrapper.IsGraphOutput(output.node_arg.Name()) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper output_tensor(output_name, output_tensor_type, output_info.qnn_data_type, + output_info.quant_param.Copy(), std::move(output_shape)); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), + "Failed to add output tensor wrapper."); + std::vector div_inputs = {sum_output, divisor_name}; + const std::string div_node_name = output_name + "_div"; + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(div_node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_DIVIDE, + {sum_output, divisor_name}, + {output_name}, + {}, + do_op_validation), + "Failed to create Mean_Div node."); + + return Status::OK(); +} + +void CreateMeanOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index f932858eb2fd9..86b684f8c6ebd 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -1,10 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
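As an editorial aside (not part of the patch): the new MeanOpBuilder above lowers ONNX Mean into a chain of QNN ElementWiseAdd nodes followed by a single ElementWiseDivide by the input count. A minimal host-side sketch of that decomposition, using hypothetical names and plain std::vector buffers instead of QNN tensors, is:

#include <cstddef>
#include <vector>

// Sketch only: mirrors the Add-chain-then-Divide lowering; assumes at least two
// inputs of equal length, matching the builder's own two-input requirement.
std::vector<float> MeanByAddThenDivide(const std::vector<std::vector<float>>& inputs) {
  std::vector<float> sum = inputs[0];                        // running "sum_output"
  for (size_t i = 1; i < inputs.size(); ++i) {               // one ElementWiseAdd per extra input
    for (size_t j = 0; j < sum.size(); ++j) {
      sum[j] += inputs[i][j];
    }
  }
  const float divisor = static_cast<float>(inputs.size());   // scalar static divisor tensor
  for (float& v : sum) {
    v /= divisor;                                            // final ElementWiseDivide
  }
  return sum;
}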
+#include +#include +#include +#include + #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { namespace qnn { @@ -16,7 +21,7 @@ class PoolOpBuilder : public BaseOpBuilder { Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; protected: Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, @@ -33,9 +38,11 @@ class PoolOpBuilder : public BaseOpBuilder { QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; private: - Status SetCommonPoolParams(const NodeAttrHelper& node_helper, std::vector& filter_size, - std::vector& pad_amount, std::vector& stride, - int32_t& ceil_mode, + Status SetCommonPoolParams(const NodeAttrHelper& node_helper, + std::vector& filter_size, + std::vector& stride, + std::vector& pad_amount, + int32_t& rounding_mode, std::vector&& input_shape, std::vector&& output_shape) const; }; @@ -50,48 +57,27 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const { ORT_UNUSED_PARAMETER(logger); - if (node_unit.Domain() == kMSInternalNHWCDomain) { // Use QNN validation API if layout is NHWC. - return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); - } - const auto& inputs = node_unit.Inputs(); ORT_RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, inputs[0].node_arg.Type())); std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape"); - bool is1d = (input_shape.size() == 3); - if (!is1d && input_shape.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Pool only supports rank 3 or 4!"); - } - - NodeAttrHelper node_helper(node_unit); - - if (is1d) { - auto kernel_shape = node_helper.Get("kernel_shape", std::vector{}); - ORT_RETURN_IF_NOT(kernel_shape.size() == 1, "QNN Pool1D: kernel_shape must have length 1!"); - - auto pads = node_helper.Get("pads", std::vector{}); - ORT_RETURN_IF_NOT(pads.size() == 2, "QNN Pool1D: pads must have length 2!"); - - auto strides = node_helper.Get("strides", std::vector{}); - ORT_RETURN_IF_NOT(strides.empty() || strides.size() == 1, "QNN Pool1D: strides must have length 1 or omitted!"); + size_t rank = input_shape.size(); + ORT_RETURN_IF_NOT(rank == 3 || rank == 4 || rank == 5, "QNN Pool only supports rank 3, 4, or 5!"); - auto dilations = node_helper.Get("dilations", std::vector{1}); - ORT_RETURN_IF_NOT(dilations.size() == 1, "QNN Pool1D: dilations must have length 1 or omitted!"); - } else { - auto dilations = node_helper.Get("dilations", std::vector{1, 1}); - ORT_RETURN_IF_NOT(dilations.size() == 2, "QNN Pool2D: dilations must have length 2 or omitted!"); - } + // ONNX MaxPool may have two outputs. 
+ ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN Pool only supports 1 output!"); - if (node_unit.Outputs().size() > 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN only support 1 output!"); - } + NodeAttrHelper node_helper(node_unit); + auto dilations = node_helper.Get("dilations", std::vector(rank - 2, 1)); + ORT_RETURN_IF_NOT(dilations == std::vector(rank - 2, 1), "QNN Pool only supports dilations 1!"); const std::string& op_type = node_unit.OpType(); - // Onnx GlobalMaxPool doesn't have any attributes - if (op_type == "GlobalMaxPool") { - return Status::OK(); + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + + if (rank == 5 && is_npu_backend) { + ORT_RETURN_IF(op_type == "MaxPool" || op_type == "GlobalMaxPool", "QNN NPU does not support PoolMax3d!"); } if (op_type == "MaxPool" || op_type == "AveragePool") { @@ -100,6 +86,10 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, "QNN Pool operators do not support 'auto_pad' value: ", auto_pad.c_str()); } + if (node_unit.Domain() == kMSInternalNHWCDomain) { // Use QNN validation API if layout is NHWC. + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); + } + return Status::OK(); } @@ -135,34 +125,47 @@ static std::vector AmendOutputShapeForRank3Pool( Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, std::vector& filter_size, - std::vector& pad_amount, std::vector& strides, - int32_t& ceil_mode, + std::vector& stride, + std::vector& pad_amount, + int32_t& rounding_mode, std::vector&& input_shape, std::vector&& output_shape) const { + size_t rank = input_shape.size(); + + // Param: filter_size. { - auto raw_filter_size = node_helper.Get("kernel_shape", std::vector{1, 1}); + auto raw_filter_size = node_helper.Get("kernel_shape", std::vector(rank - 2, 1)); if (raw_filter_size.size() == 1) { filter_size = {1, raw_filter_size[0]}; } else { filter_size = raw_filter_size; } } - ORT_RETURN_IF_NOT(filter_size.size() == 2, - "QNN only support kernel_shape with shape[2]."); + // Param: stride. + { + auto raw_stride = node_helper.Get("strides", std::vector(rank - 2, 1)); + if (raw_stride.size() == 1) { + stride = {1, raw_stride[0]}; + } else { + stride = raw_stride; + } + } + + // Param: dilations (NOT SUPPORTED by QNN). + std::vector dilations; { - auto raw_strides = node_helper.Get("strides", std::vector{1, 1}); - if (raw_strides.size() == 1) { - strides = {1, raw_strides[0]}; + auto raw_dilations = node_helper.Get("dilations", std::vector(rank - 2, 1)); + if (raw_dilations.size() == 1) { + dilations = {1, raw_dilations[0]}; } else { - strides = raw_strides; + dilations = raw_dilations; } } - ORT_RETURN_IF_NOT(strides.size() == 2, - "QNN only support strides with shape[2]."); + // Param: pad_amount. 
{ - auto raw_pad_amount = node_helper.Get("pads", std::vector{0, 0, 0, 0}); + auto raw_pad_amount = node_helper.Get("pads", std::vector((rank - 2) * 2, 0)); if (raw_pad_amount.size() == 2) { pad_amount = {0, raw_pad_amount[0], 0, raw_pad_amount[1]}; } else { @@ -171,48 +174,36 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, } auto auto_pad = node_helper.Get("auto_pad", std::string("NOTSET")); - ORT_RETURN_IF(auto_pad != "NOTSET" && auto_pad != "SAME_LOWER" && auto_pad != "SAME_UPPER" && auto_pad != "VALID", - "QNN Pool operators do not support 'auto_pad' value: ", auto_pad.c_str()); - if (auto_pad.compare("NOTSET") != 0) { - std::vector dilations; - auto raw_dilations = node_helper.Get("dilations", std::vector{1, 1}); - if (raw_dilations.size() == 1) { - dilations = {1, raw_dilations[0]}; - } else { - dilations = raw_dilations; - } - - // Max Pool rank 3 input if (output_shape.size() == 3) { - // Calculate MaxPool output for rank-4 when input is rank 3 + // Calculate rank-4 output shape for rank-3 input. output_shape = AmendOutputShapeForRank3Pool(input_shape, filter_size, - strides, + stride, pad_amount); } - auto total_pads_0 = (output_shape[1] - 1) * strides[0] + (filter_size[0] - 1) * dilations[0] + 1 - input_shape[1]; - auto total_pads_1 = (output_shape[2] - 1) * strides[1] + (filter_size[1] - 1) * dilations[1] + 1 - input_shape[2]; - if (auto_pad.compare("SAME_LOWER") != 0) { - pad_amount[0] = total_pads_0 / 2; - pad_amount[1] = total_pads_1 / 2; - pad_amount[2] = total_pads_0 - pad_amount[0]; - pad_amount[3] = total_pads_1 - pad_amount[1]; - } else if (auto_pad.compare("SAME_UPPER") != 0) { - pad_amount[2] = total_pads_0 / 2; - pad_amount[3] = total_pads_1 / 2; - pad_amount[0] = total_pads_0 - pad_amount[2]; - pad_amount[1] = total_pads_1 - pad_amount[3]; + + for (size_t axis = 0; axis < rank - 2; ++axis) { + uint32_t total_pads = (output_shape[axis + 1] - 1) * stride[axis] + + (filter_size[axis] - 1) * dilations[axis] + 1 - input_shape[axis + 1]; + if (auto_pad.compare("SAME_LOWER") == 0) { + pad_amount[axis + rank - 2] = total_pads / 2; + pad_amount[axis] = total_pads - pad_amount[axis + rank - 2]; + } else if (auto_pad.compare("SAME_UPPER") == 0) { + pad_amount[axis] = total_pads / 2; + pad_amount[axis + rank - 2] = total_pads - pad_amount[axis]; + } } } - ORT_RETURN_IF_NOT(pad_amount.size() == 4, "QNN only support pads with shape[2, 2]."); ReArranagePads(pad_amount); - ceil_mode = node_helper.Get("ceil_mode", ceil_mode); + // Param: rounding_mode. + rounding_mode = node_helper.Get("ceil_mode", rounding_mode); + return Status::OK(); -} // namespace qnn +} -void SetPoolParam(const NodeUnit& node_unit, +bool SetPoolParam(const NodeUnit& node_unit, const std::string& param_name, std::vector&& parm_shape, std::vector&& parm_data, @@ -224,7 +215,7 @@ void SetPoolParam(const NodeUnit& node_unit, std::move(parm_shape), std::move(parm_data)); param_tensor_names.push_back(qnn_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_param)); + return qnn_model_wrapper.AddParamWrapper(std::move(qnn_param)); } Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, @@ -238,6 +229,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape"); + // Reshape 3D input to 4D if necessary. 
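// Editorial worked example (not from the patch), checking the SAME-padding arithmetic in the
// auto_pad branch above: with input extent 5, filter 3, stride 2, dilation 1 and output
// extent 3, total_pads = (3 - 1) * 2 + (3 - 1) * 1 + 1 - 5 = 2. SAME_UPPER then sets
// pad_begin = total_pads / 2 = 1 and pad_end = total_pads - pad_begin = 1, while SAME_LOWER
// assigns the floored half to pad_end and the remainder to pad_begin instead.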
const auto& reshape_input = node_unit.Inputs()[0]; TensorInfo reshape_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); @@ -275,32 +267,88 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } - ORT_RETURN_IF_NOT(input_shape.size() == 4, "Input should have 4 dims NCHW or 3 dims for 1D pooling."); - // Default value for GlobalAveragePool - // Pool use filter & stride with shape [filter_height, filter_width] - // With layout transformer, the input has shape [batch, height, width, channel], - std::vector filter_size(input_shape.begin() + 1, input_shape.begin() + 3); + const std::string& op_type = node_unit.OpType(); + const size_t rank = input_shape.size(); + + // QNN constants for construction. + std::string qnn_op_type; + std::string param_filter_size; + std::string param_stride; + std::string param_pad_amount; + std::string param_count_pad_for_edges; + std::string param_rounding_mode; + if (rank == 4) { + if (op_type == "MaxPool" || op_type == "GlobalMaxPool") { + qnn_op_type = QNN_OP_POOL_MAX_2D; + param_filter_size = QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE; + param_stride = QNN_OP_POOL_MAX_2D_PARAM_STRIDE; + param_pad_amount = QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT; + param_rounding_mode = QNN_OP_POOL_MAX_2D_PARAM_ROUNDING_MODE; + } else { + qnn_op_type = QNN_OP_POOL_AVG_2D; + param_filter_size = QNN_OP_POOL_AVG_2D_PARAM_FILTER_SIZE; + param_stride = QNN_OP_POOL_AVG_2D_PARAM_STRIDE; + param_pad_amount = QNN_OP_POOL_AVG_2D_PARAM_PAD_AMOUNT; + param_count_pad_for_edges = QNN_OP_POOL_AVG_2D_PARAM_COUNT_PAD_FOR_EDGES; + param_rounding_mode = QNN_OP_POOL_AVG_2D_PARAM_ROUNDING_MODE; + } + } else { + if (op_type == "MaxPool" || op_type == "GlobalMaxPool") { + qnn_op_type = QNN_OP_POOL_MAX_3D; + param_filter_size = QNN_OP_POOL_MAX_3D_PARAM_FILTER_SIZE; + param_stride = QNN_OP_POOL_MAX_3D_PARAM_STRIDE; + param_pad_amount = QNN_OP_POOL_MAX_3D_PARAM_PAD_AMOUNT; + param_rounding_mode = QNN_OP_POOL_MAX_3D_PARAM_ROUNDING_MODE; + } else { + qnn_op_type = QNN_OP_POOL_AVG_3D; + param_filter_size = QNN_OP_POOL_AVG_3D_PARAM_FILTER_SIZE; + param_stride = QNN_OP_POOL_AVG_3D_PARAM_STRIDE; + param_pad_amount = QNN_OP_POOL_AVG_3D_PARAM_PAD_AMOUNT; + param_count_pad_for_edges = QNN_OP_POOL_AVG_3D_PARAM_COUNT_PAD_FOR_EDGES; + param_rounding_mode = QNN_OP_POOL_AVG_3D_PARAM_ROUNDING_MODE; + } + } + + // Default parameters for GlobalMaxPool/GlobalAveragePool with filter and stride in spatial shapes. + // Note that input is already in spatial-first layout. 
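// Editorial summary (not part of the patch) of the op/rank dispatch above: MaxPool and
// GlobalMaxPool map to QNN PoolMax2d for rank-4 input and PoolMax3d for rank-5 input, while
// AveragePool and GlobalAveragePool map to PoolAvg2d / PoolAvg3d. Each branch picks the
// matching FILTER_SIZE, STRIDE, PAD_AMOUNT and ROUNDING_MODE parameter names, and the
// average variants additionally pick COUNT_PAD_FOR_EDGES.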
+ std::vector filter_size(input_shape.begin() + 1, input_shape.begin() + rank - 1); + std::vector filter_size_dim{static_cast(rank - 2)}; std::vector stride(filter_size); - std::vector filter_size_dim{2}; - std::vector stride_dim{2}; - std::vector pad_amount{0, 0, 0, 0}; - std::vector pad_amount_dim{2, 2}; - int32_t ceil_mode = 0; + std::vector stride_dim{static_cast(rank - 2)}; + std::vector pad_amount((rank - 2) * 2, 0); + std::vector pad_amount_dim{static_cast(rank - 2), 2}; + int32_t rounding_mode = 0; std::vector param_tensor_names; - const std::string& op_type = node_unit.OpType(); if (op_type == "GlobalMaxPool") { - // set default params for Qnn PoolMax2D - SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, std::move(filter_size_dim), std::move(filter_size), param_tensor_names, qnn_model_wrapper); - SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), std::move(pad_amount), param_tensor_names, qnn_model_wrapper); - SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_STRIDE, std::move(stride_dim), std::move(stride), param_tensor_names, qnn_model_wrapper); + ORT_RETURN_IF_NOT(SetPoolParam(node_unit, + param_filter_size, + std::move(filter_size_dim), + std::move(filter_size), + param_tensor_names, + qnn_model_wrapper), + "Failed to add param filter_size."); + ORT_RETURN_IF_NOT(SetPoolParam(node_unit, + param_stride, + std::move(stride_dim), + std::move(stride), + param_tensor_names, + qnn_model_wrapper), + "Failed to add param stride."); + ORT_RETURN_IF_NOT(SetPoolParam(node_unit, + param_pad_amount, + std::move(pad_amount_dim), + std::move(pad_amount), + param_tensor_names, + qnn_model_wrapper), + "Failed to add param pad_amount."); ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), logger, do_op_validation, - GetQnnOpType(op_type))); + qnn_op_type)); return Status::OK(); } @@ -309,40 +357,57 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector output_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(outputs[0].node_arg, output_shape), "Cannot get shape"); - ORT_RETURN_IF_ERROR(SetCommonPoolParams(node_helper, filter_size, pad_amount, stride, ceil_mode, - std::move(input_shape), std::move(output_shape))); + ORT_RETURN_IF_ERROR(SetCommonPoolParams(node_helper, + filter_size, + stride, + pad_amount, + rounding_mode, + std::move(input_shape), + std::move(output_shape))); } + // Calculate rank-4 output shape for rank-3 input. 
std::vector onnx_in_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape"); - // Reshaped input rank-4 for MaxPool if (onnx_in_shape.size() == 3) { - onnx_in_shape = {onnx_in_shape[0], - 1, - onnx_in_shape[1], - onnx_in_shape[2]}; + onnx_in_shape = {onnx_in_shape[0], 1, onnx_in_shape[1], onnx_in_shape[2]}; } - - // Calculate MaxPool output for rank-4 when input is rank 3 - auto pooled_shape = AmendOutputShapeForRank3Pool(onnx_in_shape, - filter_size, - stride, - pad_amount); - - SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, std::move(filter_size_dim), std::move(filter_size), param_tensor_names, qnn_model_wrapper); - SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), std::move(pad_amount), param_tensor_names, qnn_model_wrapper); - SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_STRIDE, std::move(stride_dim), std::move(stride), param_tensor_names, qnn_model_wrapper); - - if (0 != ceil_mode) { - Qnn_Scalar_t rounding_mode_param = QNN_SCALAR_INIT; - rounding_mode_param.dataType = QNN_DATATYPE_UINT_32; - rounding_mode_param.int32Value = ceil_mode; + auto pooled_shape = AmendOutputShapeForRank3Pool(onnx_in_shape, filter_size, stride, pad_amount); + + // Construct param wrappers. + ORT_RETURN_IF_NOT(SetPoolParam(node_unit, + param_filter_size, + std::move(filter_size_dim), + std::move(filter_size), + param_tensor_names, + qnn_model_wrapper), + "Failed to add param filter_size."); + ORT_RETURN_IF_NOT(SetPoolParam(node_unit, + param_stride, + std::move(stride_dim), + std::move(stride), + param_tensor_names, + qnn_model_wrapper), + "Failed to add param stride."); + ORT_RETURN_IF_NOT(SetPoolParam(node_unit, + param_pad_amount, + std::move(pad_amount_dim), + std::move(pad_amount), + param_tensor_names, + qnn_model_wrapper), + "Failed to add param pad_amount."); + + if (0 != rounding_mode) { + Qnn_Scalar_t scalar_param = QNN_SCALAR_INIT; + scalar_param.dataType = QNN_DATATYPE_UINT_32; + scalar_param.int32Value = rounding_mode; QnnParamWrapper rounding_mode_param_wrapper(node_unit.Index(), node_unit.Name(), - QNN_OP_POOL_MAX_2D_PARAM_ROUNDING_MODE, - rounding_mode_param); + param_rounding_mode, + scalar_param); param_tensor_names.push_back(rounding_mode_param_wrapper.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(rounding_mode_param_wrapper)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(rounding_mode_param_wrapper)), + "Failed to add param rounding_mode."); } if (op_type == "GlobalAveragePool") { Qnn_Scalar_t scalar_param = QNN_SCALAR_INIT; @@ -350,29 +415,32 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra scalar_param.bool8Value = 1; QnnParamWrapper count_pad_for_edges_param(node_unit.Index(), node_unit.Name(), - QNN_OP_POOL_AVG_2D_PARAM_COUNT_PAD_FOR_EDGES, + param_count_pad_for_edges, scalar_param); param_tensor_names.push_back(count_pad_for_edges_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param)), + "Failed to add param count_pad_for_edges."); } else if (op_type == "AveragePool") { Qnn_Scalar_t scalar_param = QNN_SCALAR_INIT; scalar_param.dataType = QNN_DATATYPE_BOOL_8; scalar_param.bool8Value = static_cast(node_helper.Get("count_include_pad", static_cast(0)) != 0); QnnParamWrapper count_pad_for_edges_param(node_unit.Index(), node_unit.Name(), - 
QNN_OP_POOL_AVG_2D_PARAM_COUNT_PAD_FOR_EDGES, + param_count_pad_for_edges, scalar_param); param_tensor_names.push_back(count_pad_for_edges_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param)), + "Failed to add param count_include_pad."); } if (!needs_reshape) { - ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, + node_unit, std::move(input_names), std::move(param_tensor_names), logger, do_op_validation, - GetQnnOpType(op_type))); + qnn_op_type)); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 70850dc7162c8..236447cc95c3d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1066,6 +1066,21 @@ DataLayout QNNExecutionProvider::GetPreferredLayout() const { return DataLayout::NHWC; } +std::optional QNNExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const { + if (target_data_layout != DataLayout::NHWC) { + return std::nullopt; + } + + if (node_domain == kOnnxDomain && node_op_type == "Upsample") { + // Upsample is translated to QNN's Resize, which requires the NHWC layout for processing. + return true; + } + + return std::nullopt; +} + Status QNNExecutionProvider::CreateComputeFunc(std::vector& node_compute_funcs, const logging::Logger& logger) { NodeComputeInfo compute_info; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 7115708d42d8c..06f9726ae96cf 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -43,6 +43,10 @@ class QNNExecutionProvider : public IExecutionProvider { DataLayout GetPreferredLayout() const override; + std::optional ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const override; + const InlinedVector GetEpContextNodes() const override; Status OnRunStart(const onnxruntime::RunOptions& run_options) override; diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index b2f289448b013..8a5f83f636824 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -133,12 +133,12 @@ struct QnnEpFactory : OrtEpFactory { // Returns the name for the EP. Each unique factory configuration must have a unique name. // Ex: a factory that supports NPU should have a different than a factory that supports GPU. 
- static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->ep_name.c_str(); } - static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor.c_str(); } @@ -154,7 +154,7 @@ struct QnnEpFactory : OrtEpFactory { size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* p_num_ep_devices) { + size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); @@ -180,11 +180,11 @@ struct QnnEpFactory : OrtEpFactory { _In_ size_t /*num_devices*/, _In_ const OrtSessionOptions* /*session_options*/, _In_ const OrtLogger* /*logger*/, - _Out_ OrtEp** /*ep*/) { + _Out_ OrtEp** /*ep*/) noexcept { return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "QNN EP factory does not support this method."); } - static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept { // no-op as we never create an EP here. } diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh deleted file mode 100644 index b5d01b91c70ed..0000000000000 --- a/onnxruntime/core/providers/rocm/atomic/common.cuh +++ /dev/null @@ -1,362 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include -#include "core/framework/float16.h" - -typedef __half half; - -namespace onnxruntime { -namespace rocm { - -__device__ __forceinline__ void atomic_add(float *address, float value) { - atomicAdd(address, value); -} - -__device__ __forceinline__ void atomic_add(double *address, double value) { - atomicAdd(address, value); -} - -// -// ref: https://github.com/pytorch/pytorch/blob/master/aten/src/THC/THCAtomics.cuh -// -__device__ __forceinline__ void atomic_add(half *address, half value) { - unsigned int* base_address = (unsigned int*)((char*)address - ((size_t)address & 2)); - unsigned int old = *base_address; - unsigned int assumed; - unsigned short x; - - do { - assumed = old; - x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - x = __half_as_short(__float2half(__half2float(*reinterpret_cast(&x)) + __half2float(value))); - old = (size_t)address & 2 ? (old & 0xffff) | (x << 16) : (old & 0xffff0000) | x; - old = atomicCAS(base_address, assumed, old); - } while (assumed != old); -} - -__device__ __forceinline__ void atomic_add(BFloat16* address, BFloat16 value) { - unsigned int* base_address = - reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); - unsigned int old = *base_address; - unsigned int assumed; - BFloat16 bsum; - do { - assumed = old; - bsum.val = reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); - bsum = bsum + value; - old = reinterpret_cast(address) & 2 ? (old & 0xffff) | (bsum.val << 16) : (old & 0xffff0000) | bsum.val; - old = atomicCAS(base_address, assumed, old); - } while (assumed != old); -} - -// This function is added to speed up atomic add for half/bf16 type on CUDA. For ROCM, use default implementation. 
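As an editorial aside (not part of the patch): the ROCm atomic helpers deleted above implement 16-bit atomic adds by running an atomicCAS loop on the aligned 32-bit word that contains the half value. A minimal CUDA-style sketch of that pattern, with a hypothetical function name, is:

#include <cuda_fp16.h>

// Sketch only: add `value` into the 16-bit half at `address` by CAS-ing the aligned
// 32-bit word containing it (same idea as the removed atomic_add(half*, half)).
__device__ inline void AtomicAddHalfSketch(__half* address, __half value) {
  const bool high_half = (reinterpret_cast<size_t>(address) & 2) != 0;
  unsigned int* base = reinterpret_cast<unsigned int*>(
      reinterpret_cast<char*>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old_word = *base;
  unsigned int assumed;
  do {
    assumed = old_word;
    const unsigned short old_bits = high_half ? static_cast<unsigned short>(assumed >> 16)
                                              : static_cast<unsigned short>(assumed & 0xffffu);
    const float updated = __half2float(__ushort_as_half(old_bits)) + __half2float(value);
    const unsigned int new_bits = __half_as_ushort(__float2half(updated));
    const unsigned int new_word = high_half ? ((assumed & 0x0000ffffu) | (new_bits << 16))
                                            : ((assumed & 0xffff0000u) | new_bits);
    old_word = atomicCAS(base, assumed, new_word);  // retry until no other thread interfered
  } while (assumed != old_word);
}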
-template -__device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const size_t numel, T value) { - ORT_UNUSED_PARAMETER(numel); - atomic_add(start_addr + index, value); -} - -// Disable default template instantiation. -// For every type T, we need to define a specialization -// to select the right type for calling atomicCAS. -template -class AtomicCasType; - -template<> -class AtomicCasType { - public: - using type = unsigned short int; - static const unsigned int mask = 0xffu; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned short int; - static const unsigned int mask = 0xffffu; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned int; - static const unsigned int mask = 0xffffffffu; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned long long int; - static const unsigned int mask = 0xffffffffu; -}; - -template<> -class AtomicCasType { - public: - using type = int; - static const unsigned int mask = 0xffffffffu; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned long long int; - static const unsigned int mask = 0xffffffffu; -}; - -// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. -// -// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. -// It accumulate `val` into the `address` using the `func`. -// The accumulation is atomic (i.e., thread-safe). -// -// E.g., Assume ValueType is -// int8_t -// and BinaryFunc is -// struct AddFunc { -// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { -// return a + b; -// } -// This function becomes atomic_add for int8_t. -template -__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { - // Assert to ensure the following bit-wise manipulation is correct. - static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, - "ValueType must be 1-byte, 2-byte or 4-byte large."); - // Number of bytes to the lower 4-byte aligned address. - // If the current address is b1010"10", then offset = b10 = 2, - // which means the current address is 2 bytes away from - // the lower 4-byte aligned address b1010"00". - size_t offset = (size_t)address & 3; - // Find an new 4-byte aligned address `address_as_ui` lower than - // or equal to `address`. Lower than `address` so that the actual - // int8_t byte is in the 4-byte word that we load. - // - // This address has the following properties: - // 1. It is 4-byte aligned. - // 2. It is lower than or equal to `address`. - // 3. De-referencing this address may return - // a uint32_t value that contains the same int8_t - // value indicated by `address`. - // - // E.g., - // address = b101010 - // offset = b101010 & b000011 = b10 = 2 - // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", - // which is (32-bit aligned). - uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); - uint32_t old = *address_as_ui; - // E.g., offset = 2. - // address_as_ui is an address 2 bytes lower than `address`. - // - // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... - // ^ ^ ^ - // | | | - // | address <--- offset * 8 (bit)-----> address_as_ui - // | ^ - // | | - // ------------------------- *address_as_ui ----------------------- - // - // This visualization shows - // 1. the 32-bit word at address_as_ui. - // 2. the gap between address_as_ui and address. - // 3. *address_as_ui contains the int8_t value at `address`. 
- uint32_t shift = offset * 8; - uint32_t old_byte; - uint32_t newval; - uint32_t assumed; - do { - assumed = old; - // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so - // we want to select the 3rd byte (byte 2 below) from the word. - // - // Journey of a 32-bit value: - // - // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... - // - // | - // | old >> offset * 8, where offset = 2. - // | Effectively, push lower two bytes - // | out of the word. - // V - // - // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... - // - // | apply bit-wise AND, - // | & 0xff (i.e., & b11111111), - // | so that we only keep - // | the byte of interest. - // | Otherwise, overflow may - // | happen when casting this - // | 32-bit value to int8_t. - // V - // - // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... - old_byte = (old >> shift) & AtomicCasType::mask; - // Compute new int8_t value and store it to newrawvalue. - // Journey of a 32-bit value (cont'd): - // - // newrawvalue - // ... new byte 2 ... - auto newrawvalue = func(val, reinterpret_cast(old_byte)); - // Put the new int8_t value back to 32-bit word. - // Also ensure that bits not occupied by the int8_t value are 0s. - // - // Journey of a 32-bit value (cont'd): - // - // reinterpret_cast(newrawvalue) - // random values | random values | random values | ... new byte 2 ... - // - // reinterpret_cast(newrawvalue) & AtomicCasType::mask - // 00000000 | 00000000 | 00000000 | ... new byte 2 ... - newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; - // Journey of a 32-bit value (cont'd): - // - // old - // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... - // - // 0x000000ff - // 00000000 | 00000000 | 00000000 | 11111111 - // - // 0x000000ff << shift - // 00000000 | 11111111 | 00000000 | 00000000 - // - // ~(0x000000ff << shift) - // 11111111 | 00000000 | 11111111 | 11111111 - // - // old & ~(0x000000ff << shift) - // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... - // - // newval << shift - // 00000000 | ... new byte 2 ... | 00000000 | 00000000 - // - // (old & ~(0x000000ff << shift)) | (newval << shift) - // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... - newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); - old = atomicCAS(address_as_ui, assumed, newval); - } while (assumed != old); -} - -// It accumulates `val` into the `address` using the `func`. -// This function is thread-safe (i.e., atomic). -template -__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { - ValueType observed = *address, assumed, new_value; - using CasType = typename AtomicCasType::type; - static_assert(sizeof(ValueType) == sizeof(CasType), - "ValueType and CasType must have the same size for calling atomicCAS."); - auto address_as_cas_type = reinterpret_cast(address); - do { - // Record the value used to compute new value. - assumed = observed; - - // Compute expected new value. - new_value = func(observed, val); - - // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. - // 4 - // 8 - auto observed_as_cas_type = *reinterpret_cast(&observed); - auto new_value_as_cas_type = *reinterpret_cast(&new_value); - - // Call atomicCAS as if the 2-byte type variables are all unsigned short int. 
- // 4 unsigned int (or int) - // 8 unsigned long long int - auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); - - // Cast the freshly observed value in memory back to the TwoByteType. - observed = *reinterpret_cast(&cas_observed_as_cas_type); - - // Two cases: - // 1. compare-and-swap success - // a. `address` holds `new_value` - // b. `observed` becomes the new value after the assignment. - // Thus, the following `observed != new_value` is false, - // and the loop terminates. - // 2. compare-and-swap fails - // a. `address` holds a value different from `observed`, thus, - // the `new_value` is stale. - // b. `observed` becomes the fresh value observed in `address`. - // Thus, the following (observed != new_value) is true, - // and the loop continues. In the next iteration, the - // `new_value` is computed again using the fresh `observed`. - } while (observed != assumed); -} - -struct AddFunc { - template - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -struct MulFunc { - template - __device__ __forceinline__ T operator()(T a, T b) const { - return a * b; - } -}; - -struct MaxFunc { - template - __device__ __forceinline__ T operator()(T a, T b) const { - return b > a ? b : a; - } -}; - -struct MinFunc { - template - __device__ __forceinline__ T operator()(T a, T b) const { - return b < a ? b : a; - } -}; - -__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { - atomic_byte_func_with_unit32_cas(address, value, AddFunc()); -} -__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { - atomic_byte_func_with_unit32_cas(address, value, MulFunc()); -} -__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { - atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); -} -__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { - atomic_byte_func_with_unit32_cas(address, value, MinFunc()); -} - -__device__ __forceinline__ void atomic_mul(half* address, half value) { - atomic_byte_func_with_unit32_cas(address, value, MulFunc()); -} -__device__ __forceinline__ void atomic_max(half* address, half value) { - atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); -} -__device__ __forceinline__ void atomic_min(half* address, half value) { - atomic_byte_func_with_unit32_cas(address, value, MinFunc()); -} - -__device__ __forceinline__ void atomic_mul(float* address, float value) { - atomic_binary_func(address, value, MulFunc()); -} -__device__ __forceinline__ void atomic_max(float* address, float value) { - atomic_binary_func(address, value, MaxFunc()); -} -__device__ __forceinline__ void atomic_min(float* address, float value) { - atomic_binary_func(address, value, MinFunc()); -} - -__device__ __forceinline__ void atomic_mul(double* address, double value) { - atomic_binary_func(address, value, MulFunc()); -} -__device__ __forceinline__ void atomic_max(double* address, double value) { - atomic_binary_func(address, value, MaxFunc()); -} -__device__ __forceinline__ void atomic_min(double* address, double value) { - atomic_binary_func(address, value, MinFunc()); -} - - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/backward_guard.cc b/onnxruntime/core/providers/rocm/backward_guard.cc deleted file mode 100644 index 1695da092bef0..0000000000000 --- a/onnxruntime/core/providers/rocm/backward_guard.cc +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. 
All rights reserved. -// Licensed under the MIT License. -#include "core/providers/rocm/backward_guard.h" - -namespace onnxruntime { - -thread_local bool BackwardPassGuard::is_backward_pass_; - -BackwardPassGuard::BackwardPassGuard() { - is_backward_pass_ = true; -} - -BackwardPassGuard::~BackwardPassGuard() { - is_backward_pass_ = false; -} - -bool BackwardPassGuard::is_backward_pass() { - return is_backward_pass_; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/backward_guard.h b/onnxruntime/core/providers/rocm/backward_guard.h deleted file mode 100644 index 12f3c0b27410f..0000000000000 --- a/onnxruntime/core/providers/rocm/backward_guard.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#pragma once - -namespace onnxruntime { - -struct BackwardPassGuard { - BackwardPassGuard(); - ~BackwardPassGuard(); - static bool is_backward_pass(); - - private: - static thread_local bool is_backward_pass_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/composable_kernel_common.h b/onnxruntime/core/providers/rocm/composable_kernel_common.h deleted file mode 100644 index 6f504995e40a3..0000000000000 --- a/onnxruntime/core/providers/rocm/composable_kernel_common.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include "ck/utility/data_type.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#endif - -#include "core/framework/float8.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { - -#ifdef USE_COMPOSABLE_KERNEL -template -struct CKBlasOpAdaptor { - using type = std::conditional_t; -}; - -template -struct CKDataTypeAdaptor { - using type = T; -}; - -template <> -struct CKDataTypeAdaptor { - using type = ck::half_t; -}; - -template <> -struct CKDataTypeAdaptor { - using type = ck::half_t; -}; - -template <> -struct CKDataTypeAdaptor { - using type = ck::bhalf16_t; -}; - -#if !defined(DISABLE_FLOAT8_TYPES) -template <> -struct CKDataTypeAdaptor { - using type = ck::f8_t; -}; - -template <> -struct CKDataTypeAdaptor { - using type = ck::f8_t; -}; -#endif - -#endif - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh deleted file mode 100644 index b8fe875ba54b7..0000000000000 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ /dev/null @@ -1,608 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
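
The byte-wide atomic helpers deleted a few hunks above (atomic_add/mul/max/min for int8_t and half) emulate sub-word atomics on top of 32-bit atomicCAS: load the enclosing aligned word, splice the updated byte in, and retry until the CAS succeeds. A minimal sketch of that technique for an add on int8_t; names are illustrative and this is not the deleted implementation verbatim:

#include <hip/hip_runtime.h>
#include <cstdint>

// Sketch only: emulate an 8-bit atomic add with a 32-bit compare-and-swap.
__device__ inline void AtomicAddInt8Sketch(int8_t* address, int8_t value) {
  uintptr_t offset = reinterpret_cast<uintptr_t>(address) & 3;   // byte position inside the aligned word
  uint32_t* word = reinterpret_cast<uint32_t*>(
      reinterpret_cast<char*>(address) - offset);                // 4-byte aligned word containing *address
  uint32_t shift = static_cast<uint32_t>(offset) * 8;
  uint32_t old = *word, assumed;
  do {
    assumed = old;
    int8_t cur = static_cast<int8_t>((old >> shift) & 0xffu);    // extract the current byte
    uint8_t new_byte = static_cast<uint8_t>(cur + value);        // compute the new byte value
    uint32_t desired = (old & ~(0xffu << shift)) |
                       (static_cast<uint32_t>(new_byte) << shift);  // splice it back into the word
    old = atomicCAS(word, assumed, desired);                     // retry if another thread raced us
  } while (assumed != old);
}
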
- -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include -//#include -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -namespace onnxruntime { -namespace rocm { - -/// Arithmetic for BFloat16 - -__device__ __forceinline__ BFloat16 operator+(const BFloat16& a, const BFloat16& b) { - return static_cast(a) + static_cast(b); -} - -__device__ __forceinline__ BFloat16 operator-(const BFloat16& a, const BFloat16& b) { - return static_cast(a) - static_cast(b); -} - -__device__ __forceinline__ BFloat16 operator*(const BFloat16& a, const BFloat16& b) { - return static_cast(a) * static_cast(b); -} - -__device__ __forceinline__ BFloat16 operator/(const BFloat16& a, const BFloat16& b) { - return static_cast(a) / static_cast(b); -} - -__device__ __forceinline__ BFloat16 operator-(const BFloat16& a) { return -static_cast(a); } - -__device__ __forceinline__ BFloat16& operator+=(BFloat16& a, const BFloat16& b) { - a = a + b; - return a; -} - -__device__ __forceinline__ BFloat16& operator-=(BFloat16& a, const BFloat16& b) { - a = a - b; - return a; -} - -__device__ __forceinline__ BFloat16& operator*=(BFloat16& a, const BFloat16& b) { - a = a * b; - return a; -} - -__device__ __forceinline__ BFloat16& operator/=(BFloat16& a, const BFloat16& b) { - a = a / b; - return a; -} - -/// Arithmetic with floats - -__device__ __forceinline__ float operator+(BFloat16 a, float b) { return static_cast(a) + b; } -__device__ __forceinline__ float operator-(BFloat16 a, float b) { return static_cast(a) - b; } -__device__ __forceinline__ float operator*(BFloat16 a, float b) { return static_cast(a) * b; } -__device__ __forceinline__ float operator/(BFloat16 a, float b) { return static_cast(a) / b; } - -__device__ __forceinline__ float operator+(float a, BFloat16 b) { return a + static_cast(b); } -__device__ __forceinline__ float operator-(float a, BFloat16 b) { return a - static_cast(b); } -__device__ __forceinline__ float operator*(float a, BFloat16 b) { return a * static_cast(b); } -__device__ __forceinline__ float operator/(float a, BFloat16 b) { return a / static_cast(b); } - -__device__ __forceinline__ float& operator+=(float& a, const BFloat16& b) { return a += static_cast(b); } -__device__ __forceinline__ float& operator-=(float& a, const BFloat16& b) { return a -= static_cast(b); } -__device__ __forceinline__ float& operator*=(float& a, const BFloat16& b) { return a *= static_cast(b); } -__device__ __forceinline__ float& operator/=(float& a, const BFloat16& b) { return a /= static_cast(b); } - -/// Arithmetic with doubles - -__device__ __forceinline__ double operator+(BFloat16 a, double b) { return static_cast(a) + b; } -__device__ __forceinline__ double operator-(BFloat16 a, double b) { return static_cast(a) - b; } -__device__ __forceinline__ double operator*(BFloat16 a, double b) { return static_cast(a) * b; } -__device__ __forceinline__ double operator/(BFloat16 a, double b) { return static_cast(a) / b; } - -__device__ __forceinline__ double operator+(double a, BFloat16 b) { return a + static_cast(b); } -__device__ __forceinline__ double operator-(double a, BFloat16 b) { return a - static_cast(b); } -__device__ __forceinline__ double operator*(double a, BFloat16 b) { return a * static_cast(b); } -__device__ __forceinline__ double operator/(double a, BFloat16 b) { return a / static_cast(b); } - -// Overloading < and > operators - -__device__ __forceinline__ bool operator==(BFloat16& lhs, BFloat16& rhs) { return float(lhs) == 
float(rhs); } -__device__ __forceinline__ bool operator!=(BFloat16& lhs, BFloat16& rhs) { return float(lhs) != float(rhs); } -__device__ __forceinline__ bool operator>(BFloat16& lhs, BFloat16& rhs) { return float(lhs) > float(rhs); } -__device__ __forceinline__ bool operator<(BFloat16& lhs, BFloat16& rhs) { return float(lhs) < float(rhs); } - -template -__device__ __inline__ T _Ceil(T a); - -template <> -__device__ __inline__ float _Ceil(float a) { return ceilf(a); } - -template <> -__device__ __inline__ double _Ceil(double a) { return ceil(a); } - -template <> -__device__ __inline__ half _Ceil(half a) { return half(ceilf((float)a)); } - -template -__device__ __inline__ T _Floor(T a); - -template <> -__device__ __inline__ float _Floor(float a) { return floorf(a); } - -template <> -__device__ __inline__ double _Floor(double a) { return floor(a); } - -template <> -__device__ __inline__ half _Floor(half a) { return half(floorf((float)a)); } - -template -__device__ __inline__ T _Sqrt(T a); - -template <> -__device__ __inline__ float _Sqrt(float a) { return sqrtf(a); } - -template <> -__device__ __inline__ double _Sqrt(double a) { return sqrt(a); } - -template <> -__device__ __inline__ half _Sqrt(half a) { return half(sqrtf((float)a)); } - -template -__device__ __inline__ T _Erf(T a); - -template <> -__device__ __inline__ float _Erf(float a) { return erff(a); } - -template <> -__device__ __inline__ double _Erf(double a) { return erf(a); } - -template <> -__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } - -template <> -__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); } - -template -__device__ __inline__ T _Round(T a); - -template <> -__device__ __inline__ float _Round(float a) { return rintf(a); } - -template <> -__device__ __inline__ double _Round(double a) { return rint(a); } - -template <> -__device__ __inline__ half _Round(half a) { - return hrint(a); -} - -template -__device__ __inline__ T _Cos(T a); - -template <> -__device__ __inline__ float _Cos(float a) { return cosf(a); } - -template <> -__device__ __inline__ double _Cos(double a) { return cos(a); } - -template <> -__device__ __inline__ half _Cos(half a) { - return hcos(a); -} - -template -__device__ __inline__ T _Sin(T a); - -template <> -__device__ __inline__ float _Sin(float a) { return sinf(a); } - -template <> -__device__ __inline__ double _Sin(double a) { return sin(a); } - -template <> -__device__ __inline__ half _Sin(half a) { - return hsin(a); -} - -template -__device__ __inline__ T _Exp(T a); - -template <> -__device__ __inline__ float _Exp(float a) { return expf(a); } - -template <> -__device__ __inline__ double _Exp(double a) { return exp(a); } - -template <> -__device__ __inline__ half _Exp(half a) { return half(expf((float)a)); } - -template -__device__ __inline__ T _Log(T a); - -template <> -__device__ __inline__ float _Log(float a) { return logf(a); } - -template <> -__device__ __inline__ double _Log(double a) { return log(a); } - -template <> -__device__ __inline__ half _Log(half a) { return half(logf((float)a)); } - -template -__device__ __inline T _Tanh(T a); - -template <> -__device__ __inline__ float _Tanh(float a) { return tanhf(a); } - -template <> -__device__ __inline__ double _Tanh(double a) { return tanh(a); } - -template <> -__device__ __inline__ half _Tanh(half a) { return half(tanhf((float)a)); } - -template <> -__device__ __inline__ half2 _Tanh(half2 a) { - float2 tmp = (__half22float2(a)); - tmp.x = tanhf(tmp.x); - tmp.y = tanhf(tmp.y); - return 
__float22half2_rn(tmp); -} - -// Capture permutations of int32/64/float/double -template -__device__ __inline__ T _Pow(T a, T1 b) { - return static_cast(pow(static_cast(a), static_cast(b))); -} - -template <> -__device__ __inline__ float _Pow(float a, float b) { return powf(a, b); } - -template <> -__device__ __inline__ double _Pow(double a, double b) { return pow(a, b); } - -template <> -__device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); } - -#define ISNAN_BFLOAT16(v__) static_cast(*reinterpret_cast(&v__) & ~BFloat16::kSignMask) \ - > BFloat16::kPositiveInfinityBits - -// Note that there is no consistent canonical NaN for FP16 and BF16; -// HIP uses 0x7FFF for HIPRT_NAN_BF16, but ONNX Runtime uses 0x7FC1. -// (see BFloat16Impl::kPositiveQNaNBits). -#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU) - -template -__device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; } - -template <> -__device__ __inline__ float _Min(float a, float b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); -} - -template <> -__device__ __inline__ double _Min(double a, double b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); -} - -template <> -__device__ __inline__ half _Min(half a, half b) { - return __hmin_nan(a, b); -} - -template <> -__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) { - return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b); -} - -template -__device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } - -template <> -__device__ __inline__ float _Max(float a, float b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); -} - -template <> -__device__ __inline__ double _Max(double a, double b) { - return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); -} - -template <> -__device__ __inline__ half _Max(half a, half b) { - return __hmax_nan(a, b); -} - -template <> -__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) { - return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b); -} - -#undef ISNAN_BFLOAT16 -#undef NAN_BFLOAT16 - -template -__device__ __inline__ T _Abs(T a) { return a > (T)0 ? 
a : -a; } - -template -__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; } - -template -__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); } - -template -__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); } - -template <> -__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); } - -template -__device__ __inline__ T _Normcdf(T a); - -template <> -__device__ __inline__ float _Normcdf(float a) { return normcdff(a); } - -template <> -__device__ __inline__ double _Normcdf(double a) { return normcdf(a); } - -template <> -__device__ __inline__ half _Normcdf(half a) { return half(normcdff((float)a)); } - -template <> -__device__ __inline__ BFloat16 _Sqrt(BFloat16 a) { return sqrtf(static_cast(a)); } - -template <> -__device__ __inline__ BFloat16 _Exp(BFloat16 a) { return expf(static_cast(a)); } - -template <> -__device__ __inline__ BFloat16 _Log(BFloat16 a) { return logf(static_cast(a)); } - -template <> -__device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast(a)); } - -template <> -__device__ __inline__ BFloat16 _Normcdf(BFloat16 a) { return normcdff(static_cast(a)); } - -template -__device__ __inline__ T _Gelu(T a) { - return a * _Normcdf(a); -} - -template <> -__device__ __inline__ half _Gelu(half a) { - const half kHalf = half(0.5); - const half kOne = half(1.0); - const half kAlpha = half(M_SQRT1_2); - return a * kHalf * (kOne + _Erf(kAlpha * a)); -} - -template -__device__ __inline__ T _Mod(T a, T b) { - T r = a % b; - T zero = T(0); - if ((r > zero && b < zero) || (r < zero && b > zero)) { - r += b; - } - return r; -} - -template -__device__ __inline__ T _Fmod(T a, T b) { - return a % b; -} - -template <> -__device__ __inline__ float _Fmod(float a, float b) { - return fmodf(a, b); -} - -template <> -__device__ __inline__ double _Fmod(double a, double b) { - return fmod(a, b); -} - -template <> -__device__ __inline__ half _Fmod(half a, half b) { - return fmodf((float)a, (float)b); -} - -template <> -__device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { - return fmodf((float)a, (float)b); -} - -namespace isinf_details { -template -struct IsInfTyped { - static __device__ __inline__ bool IsInf(T a) { - // cast is needed because on non MS compilers, - // because there isinf() returns int - // and we want to avoid stupid warnings - return static_cast(isinf(a)); - } - static __device__ __inline__ bool IsInfPos(T a) { - return a == std::numeric_limits::infinity(); - } - static __device__ __inline__ bool IsInfNeg(T a) { - return a == -std::numeric_limits::infinity(); - } -}; - -template <> -struct IsInfTyped { - static __device__ __inline__ bool IsInf(half a) { - return MLFloat16::kPositiveInfinityBits == - static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); - } - static __device__ __inline__ bool IsInfPos(half a) { - return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); - } - static __device__ __inline__ bool IsInfNeg(half a) { - return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); - } -}; - -template <> -struct IsInfTyped { - static __device__ __inline__ bool IsInf(BFloat16 a) { - return BFloat16::kPositiveInfinityBits == - static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); - } - static __device__ __inline__ bool IsInfPos(BFloat16 a) { - return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); - } - static __device__ __inline__ bool IsInfNeg(BFloat16 a) { - return 
BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); - } -}; - -#if !defined(DISABLE_FLOAT8_TYPES) - -template -struct ReturnFalse { - constexpr static bool __device__ __inline__ IsInf(T) { return false; } - constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } - constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } -}; - -template <> -struct IsInfTyped : ReturnFalse {}; - -template <> -struct IsInfTyped : ReturnFalse {}; - -template <> -struct IsInfTyped { - static __device__ __inline__ bool IsInf(Float8E5M2 a) { - return a.val == 0b01111100 || a.val == 0b11111100; - } - static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { - return a.val == 0b01111100; - } - static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { - return a.val == 0b11111100; - } -}; - -template <> -struct IsInfTyped : ReturnFalse {}; - -#endif -} // namespace isinf_details - -template -struct _IsInf { - __device__ __inline__ bool operator()(T a) const { - if constexpr (detect_positive && detect_negative) { - return isinf_details::IsInfTyped::IsInf(a); - } else if constexpr (detect_positive) { - return isinf_details::IsInfTyped::IsInfPos(a); - } else if constexpr (detect_negative) { - return isinf_details::IsInfTyped::IsInfNeg(a); - } else { - return false; - } - } -}; - -// float and double -template -struct _IsNan { - __device__ __inline__ bool operator()(T a) const { - return isnan(a); - } -}; - -template <> -struct _IsNan { - __device__ __inline__ bool operator()(half a) const { - return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) - > MLFloat16::kPositiveInfinityBits; - } -}; - -template <> -struct _IsNan { - __device__ __inline__ bool operator()(BFloat16 a) const { - return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) - > BFloat16::kPositiveInfinityBits; - } -}; - -#if !defined(DISABLE_FLOAT8_TYPES) - -template<> -struct _IsNan { - __device__ __inline__ bool operator()(Float8E4M3FN a) const { - return (*reinterpret_cast(&a) & 0x7f) == 0x7f; - } -}; - -template<> -struct _IsNan { - __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { - return *reinterpret_cast(&a) == 0x80; - } -}; - -template<> -struct _IsNan { - __device__ __inline__ bool operator()(Float8E5M2 a) const { - uint8_t c = *reinterpret_cast(&a); - return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); - } -}; - -template<> -struct _IsNan { - __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { - return *reinterpret_cast(&a) == 0x80; - } -}; - -#endif - -// We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer -// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. 
-#ifndef HIP_LONG -#define HIP_LONG int32_t -#endif - -template -inline __host__ __device__ INT CeilDiv(INT a, INT2 b) // ceil(a/b) -{ - return (INT)(((size_t)a + (size_t)b - 1) / (size_t)b); // these size_t casts are necessary since b may be INT_MAX (for maxGridSize[]) -} - -struct GridDim { - enum : HIP_LONG { - maxThreadsPerBlock = 256, // max threads per block - maxElementsPerThread = 4, // max element processed per thread - }; -}; - -// aligned vector generates vectorized load/store on ROCM -template -struct alignas(sizeof(T) * vec_size) aligned_vector { - T val[vec_size]; -}; - -#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \ - HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x; \ - if (id >= N) \ - return; - -// HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels. -#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) - -// WARP related definitions and functions -constexpr int GPU_WARP_SIZE = warpSize; -inline int GPU_WARP_SIZE_HOST = warpSizeDynamic(); - -template -__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) { - ORT_UNUSED_PARAMETER(mask); - return __shfl(value, srcLane, width); -} - -template -__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) { - ORT_UNUSED_PARAMETER(mask); - return __shfl_xor(value, laneMask, width); -} - -template -__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) { - ORT_UNUSED_PARAMETER(mask); - return __shfl_up(value, delta, width); -} - -template -__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) { - ORT_UNUSED_PARAMETER(mask); - return __shfl_down(value, delta, width); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu deleted file mode 100644 index 97570721b0d62..0000000000000 --- a/onnxruntime/core/providers/rocm/fpgeneric.cu +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
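
The common.cuh header removed just above carries the elementwise launch idiom used throughout the ROCm EP: CeilDiv sizes the grid and CALCULATE_ELEMENTWISE_INDEX_OR_EXIT bounds each thread. A small self-contained sketch of the same pattern under a 256-threads-per-block assumption; the kernel and launcher names are illustrative, not part of the deleted header:

#include <hip/hip_runtime.h>
#include <cstdint>

constexpr int32_t kMaxThreadsPerBlock = 256;  // mirrors GridDim::maxThreadsPerBlock

// ceil(a / b) for positive integers; the deleted header routes through size_t to avoid overflow.
template <typename Int>
constexpr Int CeilDiv(Int a, Int b) { return (a + b - 1) / b; }

__global__ void ScaleKernel(float* data, float alpha, int32_t n) {
  int32_t id = blockDim.x * blockIdx.x + threadIdx.x;
  if (id >= n) return;  // the deleted macro expands to exactly this guard
  data[id] *= alpha;
}

void LaunchScale(hipStream_t stream, float* data, float alpha, int32_t n) {
  int32_t blocks = CeilDiv(n, kMaxThreadsPerBlock);
  ScaleKernel<<<blocks, kMaxThreadsPerBlock, 0, stream>>>(data, alpha, n);
}
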
- -#include -#include "core/providers/rocm/cu_inc/common.cuh" - -#define TRANS_TILE_DIM 32 -#define BLOCK_ROWS 8 -#define COPY_TILE_DIM 1024 -#define COPY_BLOCK_DIM 256 - -// kernel(s) for half functions with no library support -namespace { - -__global__ void transposeNoOverlap(half* odata, const half* idata, const int m, const int n) { - __shared__ half tile[TRANS_TILE_DIM][TRANS_TILE_DIM + 1]; - - int x = blockIdx.x * TRANS_TILE_DIM + threadIdx.x; - int y = blockIdx.y * TRANS_TILE_DIM + threadIdx.y; - - if (x < m) { - for (int j = 0; j < TRANS_TILE_DIM; j += BLOCK_ROWS) { - if (j >= (n - y)) continue; - tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * m + x]; - } - } - - __syncthreads(); - - x = blockIdx.y * TRANS_TILE_DIM + threadIdx.x; // transpose block offset - y = blockIdx.x * TRANS_TILE_DIM + threadIdx.y; - - if (x >= n) return; - - for (int j = 0; j < TRANS_TILE_DIM; j += BLOCK_ROWS) { - if ((y + j) >= m) return; - odata[(y + j) * n + x] = tile[threadIdx.x][threadIdx.y + j]; - } -} - -__global__ void CopyVectorHalf(const half* x, int incx, half* y, int incy, int n) { - int id = blockIdx.x * blockDim.x + threadIdx.x; - if (id >= n) return; - y[id * incy] = x[id * incx]; -} - -__global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onnxruntime::BFloat16* y, int incy, - int n) { - int id = blockIdx.x * blockDim.x + threadIdx.x; - if (id >= n) return; - y[id * incy] = x[id * incx]; -} - -} // namespace - -dim3 hipblasTransposeHelperDimGrid(int m, int n) { - return dim3((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1); -} - -// hipblasTransposeHelper can only be used if it won't overflow the maxGridSize y dimension size -__host__ bool CanUse_hipblasTransposeHelper_MLFloat16(int m, int n) { - dim3 dimGrid = hipblasTransposeHelperDimGrid(m, n); - - int deviceId; - hipError_t hipError = hipGetDevice(&deviceId); - if (hipError != 0) return false; - - hipDeviceProp_t deviceProp; - hipError = hipGetDeviceProperties(&deviceProp, deviceId); - if (hipError != 0) return false; - - return dimGrid.y < deviceProp.maxGridSize[1]; -} - -hipblasStatus_t hipblasTransposeHelper(hipStream_t stream, hipblasHandle_t, hipblasOperation_t , hipblasOperation_t , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) { -if (C != A) { - dim3 dimGrid((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1); - dim3 dimBlock(TRANS_TILE_DIM, BLOCK_ROWS, 1); - - transposeNoOverlap<<>>(C, A, n, m); - } else { - return HIPBLAS_STATUS_NOT_SUPPORTED; - } - return HIPBLAS_STATUS_SUCCESS; -} - -hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t, int n, const half* x, int incx, half* y, int incy) { - dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); - dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - CopyVectorHalf<<>>(x, incx, y, incy, n); - return HIPBLAS_STATUS_SUCCESS; -} - -hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t, int n, const onnxruntime::BFloat16* x, int incx, - onnxruntime::BFloat16* y, int incy) { - dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); - dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - CopyVectorBFloat16<<>>(x, incx, y, incy, n); - return HIPBLAS_STATUS_SUCCESS; -} diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc deleted file mode 100644 index 2c593a4adc41b..0000000000000 --- 
a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" - -#include "core/providers/rocm/gpu_data_transfer.h" -#include "core/providers/rocm/rocm_common.h" - -// If you make change below, please also update onnxruntime/core/providers/migraphx/gpu_data_transfer.cc -namespace onnxruntime { - -bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; -} - -common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - size_t bytes = src.SizeInBytes(); - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - - auto& src_device = src.Location().device; - auto& dst_device = dst.Location().device; - - // for the sync version of memcpy, launch to hip default stream - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { - // Copy only if the two addresses are different. - if (dst_data != src_data) { - HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); - // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); - } - } else { - // copy from other CPU memory to GPU, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - if (src_device.MemType() != OrtDevice::MemType::HOST_ACCESSIBLE) { - // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); - } - } - } else if (src_device.Type() == OrtDevice::GPU) { - // copying from GPU to CPU memory, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - } else { - // copying between cpu memory - ORT_ENFORCE(dst_data != src_data); - memcpy(dst_data, src_data, bytes); - } - - return Status::OK(); -} - -common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const { - size_t bytes = src.SizeInBytes(); - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - - auto& src_device = src.Location().device; - auto& dst_device = dst.Location().device; - - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU) { - // If source are not pinned, the memory copy will be performed synchronously. - // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, - static_cast(stream.GetHandle()))); - } else if (src_device.Type() == OrtDevice::GPU) { - // copying between GPU, this is non-blocking - if (dst_data != src_data) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, - static_cast(stream.GetHandle()))); - } - } - } else if (src_device.Type() == OrtDevice::GPU) { - // If dest are not pinned, the memory copy will be performed synchronously. - // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. 
- HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, - static_cast(stream.GetHandle()))); - } else { - if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { - // sync the stream first to make sure the data arrived - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); - } - ORT_ENFORCE(dst_data != src_data); - memcpy(dst_data, src_data, bytes); - } - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h deleted file mode 100644 index 3d35ed52fff5c..0000000000000 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/framework/data_transfer.h" -#include "core/providers/rocm/rocm_pch.h" - -namespace onnxruntime { - -class GPUDataTransfer : public IDataTransfer { - public: - GPUDataTransfer() = default; - ~GPUDataTransfer() = default; - - bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; - - // Dumpen MSVC warning about not fully overriding - using IDataTransfer::CopyTensor; - common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; - common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/integer_gemm.cc b/onnxruntime/core/providers/rocm/integer_gemm.cc deleted file mode 100644 index 2d6ee89239cee..0000000000000 --- a/onnxruntime/core/providers/rocm/integer_gemm.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/shared_inc/integer_gemm.h" - -#include "core/common/safeint.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -namespace onnxruntime { -namespace rocm { - -constexpr int roundoff(int v, int d) { - return (v + d - 1) / d * d; -} - -Status GemmInt8(int m, int n, int k, - int32_t alpha, int32_t beta, - const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc, - const RocmKernel* rocm_kernel, onnxruntime::Stream* ort_stream) { - ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); - ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null"); - ORT_ENFORCE(ort_stream != nullptr, "Rocm kernel must have the stream instance"); - - hipStream_t stream = static_cast(ort_stream->GetHandle()); - - // pad A and B to make their leading dimension be multiples of 32 - // because hipblasGemmEx requires: - // 1. leading dimension is multiples of 4 - // 2. 
A, B is 32-bit aligned - - constexpr int mask = 0x1F; - int lda_aligned = lda; - IAllocatorUniquePtr a_padded; - if ((mask & lda_aligned) != 0) { - lda_aligned = roundoff(lda, 32); - a_padded = rocm_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, ort_stream); - HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream)); - } - - int ldb_aligned = ldb; - IAllocatorUniquePtr b_padded; - if ((mask & ldb_aligned) != 0) { - ldb_aligned = roundoff(ldb, 32); - b_padded = rocm_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, ort_stream); - HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream)); - } - - auto* ort_rocm_stream = dynamic_cast(ort_stream); - auto hipblas = ort_rocm_stream->hipblas_handle_; - - HIPBLAS_RETURN_IF_ERROR(hipblasGemmEx( - hipblas, - HIPBLAS_OP_N, HIPBLAS_OP_N, - n, m, k, - &alpha, - ldb_aligned == ldb ? b : b_padded.get(), HIP_R_8I, ldb_aligned, - lda_aligned == lda ? a : a_padded.get(), HIP_R_8I, lda_aligned, - &beta, - c, HIP_R_32I, ldc, - HIPBLAS_COMPUTE_32I, - HIPBLAS_GEMM_DEFAULT)); - return Status::OK(); -} -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum.cc b/onnxruntime/core/providers/rocm/math/einsum.cc deleted file mode 100644 index 808ca2a31cc4e..0000000000000 --- a/onnxruntime/core/providers/rocm/math/einsum.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "einsum.h" - -namespace onnxruntime { - -// This function must exist due to the C++ base class constructor needing this to be defined for the vtable, but it is never called. -Status Einsum::DeviceCompute(OpKernelContext* /*context*/, const std::vector& /*inputs*/, - AllocatorPtr /*allocator*/, concurrency::ThreadPool* /*tp*/) const { - assert(false); - return Status::OK(); -} - -namespace rocm { - -ONNX_OPERATOR_KERNEL_EX( - Einsum, - kOnnxDomain, - 12, - kRocmExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - Einsum); - -Status Einsum::Compute(OpKernelContext* context) const { - return onnxruntime::Einsum::Compute(context); -} - -Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector& inputs, - AllocatorPtr allocator, concurrency::ThreadPool* tp) const { - auto* stream = context->GetComputeStream(); - ORT_RETURN_IF(!stream, "stream is null"); - auto* rocm_stream = static_cast(stream); - hipblasHandle_t hipblas_handle = rocm_stream ? 
rocm_stream->hipblas_handle_ : nullptr; - EinsumOp::EinsumRocmAssets einsum_rocm_assets(hipblas_handle, rocm_ep_, stream, Info().GetAllocator(OrtMemType::OrtMemTypeDefault)); - - // EinsumComputePreprocessor section - - auto einsum_compute_preprocessor = EinsumComputePreprocessor::Create(*einsum_equation_preprocessor_, inputs, allocator, - &einsum_rocm_assets); - - einsum_compute_preprocessor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Diagonal, - EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose); - // Compute all required metadata to be used at Einsum compute time and return error status code if one was generated - ORT_RETURN_IF_ERROR(einsum_compute_preprocessor->Run()); - - // EinsumComputeProcessor section - - if (inputs[0]->IsDataType()) { - auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, - *einsum_compute_preprocessor, - &einsum_rocm_assets); - - einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose, - EinsumOp::DeviceHelpers::RocmDeviceHelpers::MatMul, - EinsumOp::DeviceHelpers::RocmDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::RocmDeviceHelpers::DataCopy); - return einsum_compute_processor->Run(); - } else if (inputs[0]->IsDataType()) { - auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, - *einsum_compute_preprocessor, - &einsum_rocm_assets); - - einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose, - EinsumOp::DeviceHelpers::RocmDeviceHelpers::MatMul, - EinsumOp::DeviceHelpers::RocmDeviceHelpers::ReduceSum, - EinsumOp::DeviceHelpers::RocmDeviceHelpers::DataCopy); - return einsum_compute_processor->Run(); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "Einsum op: An implementation for the input type ", - inputs[0]->DataType(), " is not supported yet"); -} - -} // namespace rocm - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum.h b/onnxruntime/core/providers/rocm/math/einsum.h deleted file mode 100644 index c62e219a66499..0000000000000 --- a/onnxruntime/core/providers/rocm/math/einsum.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
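
The integer_gemm.cc removal earlier in this hunk repacks the int8 A and B matrices so their leading dimensions meet hipblasGemmEx's alignment requirements before the GEMM call. A rough sketch of just the repacking step, assuming a device-to-device copy of a rows x cols row-major int8 matrix into a pre-allocated buffer with a rounded-up leading dimension (allocation and error handling elided, helper names hypothetical):

#include <hip/hip_runtime.h>
#include <cstdint>

// Round v up to the next multiple of d (d > 0), as the deleted roundoff() did.
constexpr int RoundUp(int v, int d) { return (v + d - 1) / d * d; }

// Copy a rows x cols int8 matrix with leading dimension `ld` (in elements == bytes for int8)
// into `padded`, whose leading dimension `ld_padded` is RoundUp(ld, 32).
hipError_t PadLeadingDimSketch(hipStream_t stream, const int8_t* src, int ld,
                               int rows, int cols, int8_t* padded, int ld_padded) {
  // hipMemcpy2DAsync(dst, dpitch, src, spitch, width_bytes, height, kind, stream);
  // pitches and width are in bytes, which for int8 equals the element counts.
  return hipMemcpy2DAsync(padded, ld_padded, src, ld,
                          cols, rows, hipMemcpyDeviceToDevice, stream);
}
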
- -#pragma once - -#include "core/platform/threadpool.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/cpu/math/einsum.h" -#include "einsum_utils/einsum_auxiliary_ops.h" -#include "core/providers/rocm/rocm_execution_provider.h" - -namespace onnxruntime { -namespace rocm { - -class Einsum final : public onnxruntime::Einsum { - public: - Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) { - // We need to cast away the const as PerThreadHipblasHandle() is currently a non-const method - // TODO: Clean up the ROCMExecutionProvider interface to avoid this - rocm_ep_ = static_cast(info.GetExecutionProvider()); - } - - Status Compute(OpKernelContext* context) const override; - - private: - Status DeviceCompute(OpKernelContext* context, const std::vector& inputs, - AllocatorPtr allocator, concurrency::ThreadPool* tp) const override; - - // Members of Einsum ROCM kernel - using onnxruntime::Einsum::einsum_equation_preprocessor_; - using onnxruntime::Einsum::equation_; - - // We need to access to the ROCM EP instance to get the hipblas/miopen handles - const ROCMExecutionProvider* rocm_ep_; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc deleted file mode 100644 index 553fe1dccb332..0000000000000 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/tunable/gemm.h" - -namespace onnxruntime { -namespace concurrency { -class ThreadPool; -} -} // namespace onnxruntime - -#include "einsum_auxiliary_ops.h" - -namespace onnxruntime { - -namespace EinsumOp { - -namespace DeviceHelpers { - -namespace RocmDeviceHelpers { - -// ROCM EP specific Data copy helper -Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets) { - ORT_ENFORCE(output.SizeInBytes() == input.SizeInBytes(), - "Einsum op: The candidate output does not match the actual output's shape"); - // There are no string tensors in Einsum's case - so safely use memcpy - HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), input.SizeInBytes(), - hipMemcpyDeviceToDevice, - static_cast(einsum_rocm_assets)->GetRocmStream())); - - return Status::OK(); -} - -// ROCM EP specific Transpose helper -Status Transpose(const gsl::span& permutation, const Tensor& input, - Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets) { - return rocm::Transpose::DoTranspose(static_cast(einsum_rocm_assets)->rocm_ep_->GetDeviceProp(), - static_cast(einsum_rocm_assets)->GetRocmStream(), - static_cast(einsum_rocm_assets)->hipblas_handle_, - permutation, input, output, input_shape_override); -} - -// ROCM EP specific MatMul helper -template -Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, - size_t left_stride, size_t right_stride, size_t output_stride, - size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* /*tp*/, - void* einsum_rocm_assets) { - typedef typename rocm::ToHipType::MappedType HipT; - - namespace blas = rocm::tunable::blas; - return blas::column_major::StridedBatchedGemm( - static_cast( - static_cast(einsum_rocm_assets)->rocm_ep_->GetTuningContext()), - static_cast(einsum_rocm_assets)->ort_stream_, - 
static_cast(einsum_rocm_assets)->hipblas_handle_, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - N, M, K, - /*alpha=*/1.0f, - reinterpret_cast(input_2_data), N, right_stride, - reinterpret_cast(input_1_data), K, left_stride, - /*beta=*/0.0f, - reinterpret_cast(output_data), N, output_stride, - num_batches); -} - -// ROCM EP specific ReduceSum helper -template -std::unique_ptr ReduceSum(const Tensor& input, gsl::span reduce_axes, - bool keep_dims, AllocatorPtr allocator, - const TensorShape* input_shape_override, - concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets) { - return rocm::ReductionOps::ReduceCompute(static_cast(einsum_rocm_assets)->gpu_allocator_, MIOPEN_REDUCE_TENSOR_ADD, - allocator, input, reduce_axes, - keep_dims, false, false, false, - true, static_cast(einsum_rocm_assets)->ort_stream_, - input_shape_override); -} - -// ROCM EP specific Diagonal helper -std::unique_ptr Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets) { - const auto& input_shape = input.Shape(); - const auto& input_dims = input_shape.GetDims(); - auto rank = static_cast(input_dims.size()); - - ORT_ENFORCE(rank >= 2 && dim_1 != dim_2 && input_dims[dim_1] == input_dims[dim_2], - "Cannot parse the diagonal elements along dims ", dim_1, " and ", dim_2, " for input shape ", input_shape); - - int64_t first_dim = -1; // first_dim holds the lesser of dim_1 and dim_2 - int64_t second_dim = -1; // second_dim holds the greater of dim_1 and dim_2 - if (dim_1 < dim_2) { - first_dim = dim_1; - second_dim = dim_2; - } else { - first_dim = dim_2; - second_dim = dim_1; - } - - // Make a copy - we are going to mutate the dims - TensorShapeVector output_dims = input_shape.AsShapeVector(); - - // Remove the dim value in `second_dim` - - // The diagonal values are stored along `first_dim` - output_dims.erase(output_dims.begin() + second_dim); - - auto output = Tensor::Create(input.DataType(), output_dims, allocator); - - TensorPitches input_strides(input.Shape().GetDims()); - rocm::TArray gpu_input_strides(input_strides); - - auto output_rank = static_cast(output_dims.size()); - rocm::TArray gpu_output_strides(output_rank); - TensorPitches output_strides(output_dims); - for (auto i = 0; i < output_rank; i++) { - gpu_output_strides[i] = rocm::fast_divmod(static_cast(output_strides[i])); - } - - DiagonalImpl( - static_cast(einsum_rocm_assets)->GetRocmStream(), - input.DataRaw(), - input.Shape().GetDims().size(), - first_dim, - second_dim, - gpu_input_strides, - output->MutableDataRaw(), - gpu_output_strides, - TensorShape(output_dims).Size(), - input.DataType()->Size()); - - return output; -} - -} // namespace RocmDeviceHelpers - -} // namespace DeviceHelpers - -// Explicit template instantiations of functions - -// float -template Status DeviceHelpers::RocmDeviceHelpers::MatMul( - const float* input_1_data, const float* input_2_data, float* output_data, - size_t left_stride, size_t right_stride, size_t output_stride, - size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp, - void* einsum_rocm_assets); - -template std::unique_ptr DeviceHelpers::RocmDeviceHelpers::ReduceSum( - const Tensor& input, gsl::span reduce_axes, - bool keep_dims, AllocatorPtr allocator, - const TensorShape* input_shape_override, - concurrency::ThreadPool* tp, void* einsum_rocm_assets); - -// MLFloat16 -template Status DeviceHelpers::RocmDeviceHelpers::MatMul( - const MLFloat16* input_1_data, const MLFloat16* input_2_data, MLFloat16* output_data, - size_t 
left_stride, size_t right_stride, size_t output_stride, - size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp, - void* einsum_rocm_assets); - -template std::unique_ptr DeviceHelpers::RocmDeviceHelpers::ReduceSum( - const Tensor& input, gsl::span reduce_axes, - bool keep_dims, AllocatorPtr allocator, - const TensorShape* input_shape_override, - concurrency::ThreadPool* tp, void* einsum_rocm_assets); - -} // namespace EinsumOp - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h deleted file mode 100644 index 689c65ae29f82..0000000000000 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// This module hosts implementations and thin wrappers over other onnx operator implementations -// that will be called from within the Einsum operator implementation - -#pragma once - -#include "core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h" -#include "core/providers/rocm/tensor/transpose.h" -#include "core/providers/rocm/reduction/reduction_ops.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/cpu/tensor/utils.h" -#include "einsum_auxiliary_ops_diagonal.h" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { - -namespace EinsumOp { - -// Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow -struct EinsumRocmAssets { - explicit EinsumRocmAssets(hipblasHandle_t hipblas_handle, - const ROCMExecutionProvider* rocm_ep, - Stream* ort_stream, AllocatorPtr gpu_allocator) : hipblas_handle_(hipblas_handle), - rocm_ep_(rocm_ep), - ort_stream_(ort_stream), - gpu_allocator_(gpu_allocator) {} - - hipStream_t GetRocmStream() { - return ort_stream_ ? 
static_cast(ort_stream_->GetHandle()) : nullptr; - } - - hipblasHandle_t hipblas_handle_; - const ROCMExecutionProvider* rocm_ep_; - Stream* ort_stream_; - AllocatorPtr gpu_allocator_; -}; - -namespace DeviceHelpers { - -// These are ROCM EP specific device helper implementations -namespace RocmDeviceHelpers { - -Status Transpose(const gsl::span& permutation, const Tensor& input, - Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets); - -Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets); - -template -Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, - size_t left_stride, size_t right_stride, size_t output_stride, - size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp, - void* einsum_rocm_assets); - -template -std::unique_ptr ReduceSum(const Tensor& input, gsl::span reduce_axes, - bool keep_dims, AllocatorPtr allocator, - const TensorShape* input_shape_override, - concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets); - -std::unique_ptr Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets); - -} // namespace RocmDeviceHelpers - -} // namespace DeviceHelpers - -} // namespace EinsumOp - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu deleted file mode 100644 index e1c89a386dafc..0000000000000 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "einsum_auxiliary_ops_diagonal.h" - -namespace onnxruntime { - -namespace rocm { - -template -__global__ void _DiagonalKernel( - const T* input_data, - const int64_t input_rank, - const int64_t dim_1, - const int64_t dim_2, - const TArray input_strides, - T* output_data, - const TArray output_strides, - const size_t output_size) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(output_idx, output_size); - int dim = 0; - int remain = output_idx; - size_t input_idx = 0; - int64_t current_input_axis = 0; - - // Output's rank is always 1 less than the input's rank - for (int i = 0; i < input_rank - 1; ++i) { - output_strides[i].divmod(remain, dim, remain); - if (i == dim_1) { - // Process dim_2 as dim_2 needs to have the same dim value as dim_1 - // For example: given a tensor of shape [2, 3, 3] and parsing the diagonal along axes `1` and `2` - // we need to parse elements in input[j, i, i] (j -> 0 to 1; and i -> 0 to 2) - // and place them in output[j, i] and by definition of diagonal parsing dim_1 has to be equal to - // dim_2 - input_idx += input_strides[dim_2] * dim; - } - input_idx += input_strides[current_input_axis] * dim; - - // Update current_input_axis - // If it is dim_2, skip it - if (++current_input_axis == dim_2) { - ++current_input_axis; - } - } - output_data[output_idx] = input_data[input_idx]; -} - -void DiagonalImpl( - hipStream_t stream, - const void* input_data, - const int64_t input_rank, - const int64_t dim_1, - const int64_t dim_2, - const TArray input_strides, - void* output_data, - const TArray output_strides, - const size_t output_size, - size_t element_size) { - if (output_size > 0) { - int blocksPerGrid = static_cast((output_size + GridDim::maxThreadsPerBlock - 1) / 
GridDim::maxThreadsPerBlock); - - switch (element_size) { - case sizeof(int32_t): - _DiagonalKernel<<>>( - reinterpret_cast::MappedType*>(input_data), input_rank, dim_1, dim_2, - input_strides, reinterpret_cast::MappedType*>(output_data), output_strides, - output_size); - break; - - case sizeof(int64_t): - _DiagonalKernel<<>>( - reinterpret_cast::MappedType*>(input_data), input_rank, dim_1, dim_2, - input_strides, reinterpret_cast::MappedType*>(output_data), output_strides, - output_size); - break; - - case sizeof(int16_t): - _DiagonalKernel<<>>( - reinterpret_cast(input_data), input_rank, dim_1, dim_2, - input_strides, reinterpret_cast(output_data), output_strides, - output_size); - break; - - // Should not hit this as we do not register kernel support for types that will run into this - default: - ORT_THROW("Einsum Op: Diagonal parsing unsupported"); - } - } -} - -} // namespace rocm - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.h deleted file mode 100644 index 4742b5338ec16..0000000000000 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/shared_inc/rocm_utils.h" - -namespace onnxruntime { - -namespace rocm { - -void DiagonalImpl( - hipStream_t stream, - const void* input_data, - const int64_t input_rank, - const int64_t dim_1, - const int64_t dim_2, - const TArray input_strides, - void* output_data, - const TArray output_strides, - const size_t output_size, - size_t element_size); - -} // namespace rocm - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/gemm.cc b/onnxruntime/core/providers/rocm/math/gemm.cc deleted file mode 100644 index 529b48f736d50..0000000000000 --- a/onnxruntime/core/providers/rocm/math/gemm.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
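
The diagonal kernel deleted just above gathers input[..., i, ..., i, ...] into an output of rank one less, reconstructing the input offset from per-axis strides. A plain CPU reference of the same mapping for the rank-3 case mentioned in its comments (shape [batch, n, n], diagonal over axes 1 and 2), useful for checking the stride arithmetic; this is a reference only, not the GPU implementation:

#include <cstddef>
#include <vector>

// Reference: out[b][i] = in[b][i][i] for an input of shape [batch, n, n].
std::vector<float> DiagonalRef(const std::vector<float>& in, size_t batch, size_t n) {
  std::vector<float> out(batch * n);
  for (size_t b = 0; b < batch; ++b) {
    for (size_t i = 0; i < n; ++i) {
      // Row-major strides are [n*n, n, 1]; the two diagonal axes both advance by i.
      out[b * n + i] = in[b * n * n + i * n + i];
    }
  }
  return out;
}
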
- -#include "core/providers/rocm/math/gemm.h" - -#include "core/providers/cpu/math/gemm_helper.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" - -namespace onnxruntime { -namespace rocm { - -using tunable::blas::BlasOp; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Gemm, \ - kOnnxDomain, \ - 7, \ - 8, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Gemm); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Gemm, \ - kOnnxDomain, \ - 9, \ - 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Gemm); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Gemm, \ - kOnnxDomain, \ - 11, \ - 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Gemm); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Gemm, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Gemm); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -template -Status Gemm::ComputeInternal(OpKernelContext* ctx) const { - typedef typename ToHipType::MappedType HipT; - - const auto* X = ctx->Input(0); - const auto* W = ctx->Input(1); - const auto* B = ctx->Input(2); - // Bias could be missing. Treat as scalar 0 if that is the case. - GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B != nullptr ? B->Shape() : TensorShape({})); - - if (!helper.State().IsOK()) - return helper.State(); - - ptrdiff_t M = helper.M(); - ptrdiff_t N = helper.N(); - ptrdiff_t K = helper.K(); - auto* Y = ctx->Output(0, {M, N}); - HipT* out_data = reinterpret_cast(Y->MutableData()); - - // broadcast bias if needed and is present - if (beta_ != 0 && B != nullptr) { - auto& b_shape = B->Shape(); - const HipT* b_data = reinterpret_cast(B->Data()); - - if (b_shape.Size() == 1) { - // if B is (), (1,) or (1, 1), broadcast the scalar - HIPBLAS_RETURN_IF_ERROR(hipblasCopyHelper( - Stream(ctx), - GetHipblasHandle(ctx), - M * N, - b_data, - 0, - out_data, - 1)); - } else if (b_shape.NumDimensions() == 1 || b_shape[0] == 1) { - // B is (N,) or (1, N), broadcast using Y(N,M) = 1 * B(N,1) x ones(1,M) + 0 * Y - ORT_RETURN_IF_ERROR(tunable::blas::column_major::Gemm( - GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), - tunable::blas::BlasOp::NonTrans, - tunable::blas::BlasOp::NonTrans, - N, M, 1, - /*alpha=*/1.0f, - b_data, N, - GetConstOnes(M, Stream(ctx)), 1, - /*beta=*/0.0f, - out_data, N)); - } else if (b_shape.NumDimensions() == 2 && b_shape[1] == 1) { - // B is (M, 1), broadcast using Y(N,M) = 1 * ones(N,1) x B(1,M) + 0 * Y - ORT_RETURN_IF_ERROR(tunable::blas::column_major::Gemm( - GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), - tunable::blas::BlasOp::NonTrans, - tunable::blas::BlasOp::NonTrans, - N, M, 1, - /*alpha=*/1.0f, - GetConstOnes(N, Stream(ctx)), N, - b_data, 1, - /*beta=*/0.0f, - out_data, N)); - } else { - // B is (M, N), no broadcast needed. 
- HIP_RETURN_IF_ERROR(hipMemcpyAsync(out_data, b_data, M * N * sizeof(T), hipMemcpyDeviceToDevice, Stream(ctx))); - } - } - - return tunable::blas::column_major::Gemm( - GetTuningContext(), ctx->GetComputeStream(), - GetHipblasHandle(ctx), - trans_B_ ? BlasOp::Trans : BlasOp::NonTrans, - trans_A_ ? BlasOp::Trans : BlasOp::NonTrans, - N, M, K, - alpha_, - reinterpret_cast(W->Data()), (trans_B_ ? K : N), - reinterpret_cast(X->Data()), (trans_A_ ? M : K), - // ideally we need to set the output buffer contents to 0 if bias is missing, - // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer - B != nullptr ? beta_ : 0.0f, - out_data, N); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/matmul.cc b/onnxruntime/core/providers/rocm/math/matmul.cc deleted file mode 100644 index 1aad3fa6e1fcc..0000000000000 --- a/onnxruntime/core/providers/rocm/math/matmul.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/math/matmul.h" - -#include "core/providers/cpu/math/matmul_helper.h" -#include "core/providers/rocm/math/matmul_impl.h" - -namespace onnxruntime { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - MatMul, \ - kOnnxDomain, \ - 1, 8, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MatMul); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - MatMul, \ - kOnnxDomain, \ - 9, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MatMul); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MatMul, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MatMul); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -template -Status MatMul::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* left_X = ctx->Input(0); - const Tensor* right_X = ctx->Input(1); - - // Ignore the transpose flag if rank of input being 1. - // Be noted: numpy.transpose on vector does not change anything. - bool transa = trans_A_; - bool transb = trans_B_; - if (left_X->Shape().NumDimensions() == 1) { - transa = false; - } - if (right_X->Shape().NumDimensions() == 1) { - transb = false; - } - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(left_X->Shape(), right_X->Shape(), transa, - transb, trans_batch_a_, trans_batch_b_, - false)); - - Tensor* Y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (Y->Shape().Size() == 0) return Status::OK(); - - return MatMulImpl(this, helper, reinterpret_cast(left_X->Data()), - reinterpret_cast(right_X->Data()), - reinterpret_cast(Y->MutableData()), - left_X->Shape(), right_X->Shape(), - transa, transb, trans_batch_a_, trans_batch_b_, alpha_, ctx->GetComputeStream()); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/matmul_impl.cc b/onnxruntime/core/providers/rocm/math/matmul_impl.cc deleted file mode 100644 index e27a7e7575da7..0000000000000 --- a/onnxruntime/core/providers/rocm/math/matmul_impl.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
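Aside on the Gemm and MatMul kernels removed above: ORT tensors are row-major while hipBLAS-style GEMMs are column-major, so the deleted code computes Y = X*W by asking the column-major routine for Y^T = W^T * X^T, i.e. it swaps the operands and passes (N, M, K). A minimal self-contained sketch of that identity with a naive column-major GEMM (illustrative plain C++, not the deleted HIP path):

#include <cstdio>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n), no transposes.
void GemmColMajor(int m, int n, int k,
                  const float* A, int lda,
                  const float* B, int ldb,
                  float* C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

int main() {
  // Row-major A(2x3), B(3x2); expect row-major C(2x2) = {58, 64, 139, 154}.
  std::vector<float> A = {1, 2, 3, 4, 5, 6};
  std::vector<float> B = {7, 8, 9, 10, 11, 12};
  std::vector<float> C(4);
  int M = 2, N = 2, K = 3;
  // Swap operands and dimensions: the column-major (N x M) result buffer is
  // exactly the row-major (M x N) product.
  GemmColMajor(N, M, K, B.data(), N, A.data(), K, C.data(), N);
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // 58 64 / 139 154
  return 0;
}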
-// Licensed under the MIT License. - -#include "core/providers/rocm/math/matmul_impl.h" - -#include "core/providers/rocm/rocm_allocator.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/tunable/gemm.h" - -namespace onnxruntime { -namespace rocm { - -// StridedBatchedGemm can be used for the following GEMM computation -// C[pnm] = A[pnk]*B[km] or C[pnm] = A[pnk]*B[pkm] -static bool CanUseStridedBatchedGemm(const TensorShape& left_shape, - const TensorShape& right_shape, - bool transa, bool transb, - bool trans_batch_a, bool trans_batch_b, - int64_t& stride_A, int64_t& stride_B, - int64_t& stride_C, int64_t& batch_count) { - size_t left_num_dims = left_shape.NumDimensions(); - size_t right_num_dims = right_shape.NumDimensions(); - - if (!(left_num_dims >= 3 && right_num_dims >= 2)) { - return false; - } - - size_t left_leading_axis = trans_batch_a ? 0 : left_num_dims - 2; - size_t right_leading_axis = trans_batch_b ? 0 : right_num_dims - 2; - int64_t left_p = left_shape.SizeToDimension(left_num_dims - 2); - if (trans_batch_a) { - left_p = left_p * left_shape[left_num_dims - 2] / left_shape[0]; - } - int64_t left_k = - transa ? left_shape[left_leading_axis] : left_shape[left_num_dims - 1]; - - if (right_num_dims >= 3) { - int64_t right_p = right_shape.SizeToDimension(right_num_dims - 2); - if (trans_batch_b) { - right_p = right_p * right_shape[right_num_dims - 2] / right_shape[0]; - } - if (left_p != right_p) { - return false; - } - } - - int64_t right_k = transb ? right_shape[right_num_dims - 1] - : right_shape[right_leading_axis]; - if (left_k != right_k) { - return false; - } - - int64_t n = - transa ? left_shape[left_num_dims - 1] : left_shape[left_leading_axis]; - int64_t m = transb ? right_shape[right_leading_axis] - : right_shape[right_num_dims - 1]; - stride_A = n * left_k / (trans_batch_a ? left_shape[0] : 1); - stride_B = right_num_dims == 2 ? 0 : right_k * m / (trans_batch_b ? right_shape[0] : 1); - stride_C = n * m; - batch_count = left_p; - return true; -} - -template -Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, - const T* left_x_data, const T* right_x_data, T* output_y_data, - const TensorShape& left_shape, const TensorShape& right_shape, - bool transa, bool transb, bool trans_batch_a, bool trans_batch_b, - const float alpha, onnxruntime::Stream* stream) { - typedef typename ToHipType::MappedType HipT; - - using tunable::blas::BlasOp; - BlasOp transA = transa ? BlasOp::Trans : BlasOp::NonTrans; - BlasOp transB = transb ? 
BlasOp::Trans : BlasOp::NonTrans; - - const int lda = helper.Lda(transa); - const int ldb = helper.Ldb(transb); - const int ldc = helper.Ldc(); - int64_t stride_A, stride_B, stride_C, batch_count; - - auto hipblasHandle_t = op->GetHipblasHandle(static_cast(stream)); - - if (helper.OutputOffsets().size() == 1) { - return tunable::blas::column_major::Gemm( - op->GetTuningContext(), stream, hipblasHandle_t, - transB, transA, - helper.N(), helper.M(), helper.K(), - alpha, - reinterpret_cast(right_x_data), ldb, - reinterpret_cast(left_x_data), lda, - /*beta=*/0.0f, - reinterpret_cast(output_y_data), ldc); - } else if (CanUseStridedBatchedGemm(left_shape, right_shape, - transa, transb, trans_batch_a, trans_batch_b, - stride_A, stride_B, stride_C, batch_count)) { - return tunable::blas::column_major::StridedBatchedGemm( - op->GetTuningContext(), stream, hipblasHandle_t, - transB, transA, - helper.N(), helper.M(), helper.K(), - alpha, - reinterpret_cast(right_x_data), ldb, stride_B, - reinterpret_cast(left_x_data), lda, stride_A, - /*beta=*/0.0f, - reinterpret_cast(output_y_data), ldc, stride_C, - batch_count); - } - - // Fill offsets when needed. - helper.FillOffsets(); - RocmKernel::RocmAsyncBuffer left_arrays(op, helper.LeftOffsets().size()); - RocmKernel::RocmAsyncBuffer right_arrays(op, helper.RightOffsets().size()); - RocmKernel::RocmAsyncBuffer output_arrays(op, helper.OutputOffsets().size()); - MatMulComputeHelper::OffsetToArrays( - reinterpret_cast(left_x_data), - helper.LeftOffsets(), left_arrays.CpuSpan()); - MatMulComputeHelper::OffsetToArrays( - reinterpret_cast(right_x_data), - helper.RightOffsets(), right_arrays.CpuSpan()); - MatMulComputeHelper::OffsetToArrays( - reinterpret_cast(output_y_data), - helper.OutputOffsets(), output_arrays.CpuSpan()); - ORT_RETURN_IF_ERROR(left_arrays.CopyToGpu(stream)); - ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(stream)); - ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(stream)); - - // note that onnxruntime OrtValue is row major, while hipblas is column major, - // so swap left/right operands - return tunable::blas::column_major::BatchedGemm( - op->GetTuningContext(), stream, hipblasHandle_t, - transB, transA, - helper.N(), helper.M(), helper.K(), - alpha, - right_arrays.GpuPtr(), ldb, - left_arrays.GpuPtr(), lda, - /*beta=*/0.0f, - output_arrays.GpuPtr(), ldc, - static_cast(helper.OutputOffsets().size())); -} - -#define SPECIALIZED_IMPL(T) \ - template Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, \ - const T* left_x_data, const T* right_x_data, T* output_y_data, \ - const TensorShape& left_shape, const TensorShape& right_shape, \ - bool transa, bool transb, \ - bool trans_batch_a, bool trans_batch_b, \ - const float t_alpha, onnxruntime::Stream* stream); - -SPECIALIZED_IMPL(float) -SPECIALIZED_IMPL(double) -SPECIALIZED_IMPL(MLFloat16) -SPECIALIZED_IMPL(BFloat16) - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/matmul_impl.h b/onnxruntime/core/providers/rocm/math/matmul_impl.h deleted file mode 100644 index d0e13a34023b9..0000000000000 --- a/onnxruntime/core/providers/rocm/math/matmul_impl.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
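Aside on CanUseStridedBatchedGemm in the matmul_impl.cc removed above: a batched MatMul of the form C[p,n,m] = A[p,n,k]*B[k,m] or C[p,n,m] = A[p,n,k]*B[p,k,m] can be issued as a single strided-batched GEMM with per-batch strides n*k, 0 or k*m, and n*m. The simplified sketch below mirrors that shape check but omits the transpose and trans-batch flags the deleted helper also handles.

#include <cstdint>
#include <vector>

struct StridedBatchParams {
  int64_t stride_A, stride_B, stride_C, batch_count;
};

bool CanUseStridedBatchedGemmSimple(const std::vector<int64_t>& left,   // [p, n, k]
                                    const std::vector<int64_t>& right,  // [k, m] or [p, k, m]
                                    StridedBatchParams& out) {
  if (left.size() != 3 || (right.size() != 2 && right.size() != 3)) return false;
  const int64_t p = left[0], n = left[1], k = left[2];
  const bool batched_rhs = right.size() == 3;
  if (batched_rhs && right[0] != p) return false;            // batch dims must match
  const int64_t rk = batched_rhs ? right[1] : right[0];
  const int64_t m = batched_rhs ? right[2] : right[1];
  if (rk != k) return false;                                 // inner dims must match
  out = {n * k, batched_rhs ? k * m : 0, n * m, p};          // per-batch strides
  return true;
}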
- -#pragma once - -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/cpu/math/matmul_helper.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace rocm { - -template -Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, - const T* left_x_data, const T* right_x_data, T* output_y_data, - const TensorShape& left_shape, const TensorShape& right_shape, - bool transa, bool transb, bool trans_batch_a, bool trans_batch_b, - const float alpha, onnxruntime::Stream* stream); - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc deleted file mode 100644 index a41934d38177d..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/math/softmax.h" - -#include "core/providers/common.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/rocm/shared_inc/accumulation_type.h" -#include "core/providers/rocm/tensor/transpose.h" - -namespace onnxruntime { -namespace rocm { - -template -Status SoftMaxComputeHelper( - Stream* stream, - const T* X, - const TensorShape& input_shape, - TOut* Y, - int64_t axis, - RocmTuningContext* tuning_ctx) { - typedef typename ToHipType::MappedType HipT_IN; - typedef typename ToHipType::MappedType HipT_OUT; - typedef typename ToHipType::MappedType HipT_ACCUM; - - int64_t N = input_shape.SizeToDimension(axis); - int64_t D = input_shape.SizeFromDimension(axis); - auto Y_data = reinterpret_cast(Y); - auto X_data = reinterpret_cast(X); - - if (D <= 1024 && D * sizeof(T) <= 4096) { - return dispatch_warpwise_softmax_forward< - HipT_IN, HipT_OUT, AccumulationType_t, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); - } - - return dispatch_blockwise_softmax_forward, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(D), - gsl::narrow_cast(N), tuning_ctx); -} - -#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ - const TensorShape& shape, TOut* Y, int64_t axis, \ - RocmTuningContext* tuning_ctx); \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ - const TensorShape& shape, TOut* Y, int64_t axis, \ - RocmTuningContext* tuning_ctx); - -SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16, float) -SPECIALIZED_SOFTMAX_HELPER_IMPL(float, float) -// MIOpen double data type not supported -// SPECIALIZED_SOFTMAX_HELPER_IMPL(double, double) -SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16, MLFloat16) -SPECIALIZED_SOFTMAX_HELPER_IMPL(BFloat16, BFloat16) - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Softmax, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Softmax); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Softmax, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Softmax); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Softmax, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - 
Softmax); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - LogSoftmax, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Softmax); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - LogSoftmax, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Softmax); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - LogSoftmax, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Softmax); - -template -Status Softmax::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* X = ctx->Input(0); - const TensorShape& input_shape{X->Shape()}; - size_t rank = input_shape.NumDimensions(); - Tensor* Y = ctx->Output(0, input_shape); - - // special case when there is a dim value of 0 in the shape. - if (input_shape.Size() == 0) - return Status::OK(); - - // handle negative and enforce axis is valid - const size_t axis = static_cast(HandleNegativeAxis(axis_, rank)); - - bool is_transpose_required = false; - std::unique_ptr transposed_input; - std::vector transposed_input_dims; - std::unique_ptr intermediate_output; // output that the softmax implementation will write into while using transposed input - std::vector permutation(rank); - - // The "semantic" meaning of axis has changed in opset-13. - // Please compare: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softmax - // with https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Softmax-11 for detailed explanations - // To account for the opset-13 behavior, our plan will be to transpose the "axis" dim to the innermost dim - // and perform softmax and then reverse the transpose. 
We can skip the transposing aspect if the axis is already - // the innermost dim - if (opset_ >= 13 && axis != (rank - 1)) { - is_transpose_required = true; - } - - if (is_transpose_required) { - AllocatorPtr alloc; - auto status = ctx->GetTempSpaceAllocator(&alloc); - if (!status.IsOK()) - return status; - - std::iota(std::begin(permutation), std::end(permutation), 0); - - // swap the innermost dim with the dim corresponding to axis - permutation[axis] = rank - 1; - permutation[rank - 1] = axis; - - transposed_input_dims.reserve(rank); - for (auto e : permutation) { - transposed_input_dims.push_back(input_shape[e]); - } - - // Allocate a temporary tensor to hold transposed input - auto temp_input = Tensor::Create(X->DataType(), TensorShape(transposed_input_dims), alloc); - - // Perform the transpose - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), - Stream(ctx), - GetHipblasHandle(ctx), - permutation, *X, *temp_input)); - transposed_input = std::move(temp_input); - - // Allocate memory for the intermediate output - intermediate_output = Tensor::Create(Y->DataType(), TensorShape(transposed_input_dims), alloc); - } - - const T* X_data = nullptr; - T* Y_data = nullptr; - const TensorShape* compute_input_shape = nullptr; - - if (is_transpose_required) { // use intermediate buffers to compute the softmax values - X_data = transposed_input->Data(); - Y_data = intermediate_output->MutableData(); - compute_input_shape = &transposed_input->Shape(); - } else { // use the node input/output directly - X_data = X->Data(); - Y_data = Y->MutableData(); - compute_input_shape = &input_shape; - } - - Status status; - if (log_softmax_) { - status = SoftMaxComputeHelper(ctx->GetComputeStream(), X_data, *compute_input_shape, Y_data, - is_transpose_required ? static_cast(rank) - 1 - : static_cast(axis), - GetTuningContext()); - } else { - status = SoftMaxComputeHelper(ctx->GetComputeStream(), X_data, *compute_input_shape, Y_data, - is_transpose_required ? static_cast(rank) - 1 - : static_cast(axis), - GetTuningContext()); - } - - if (!status.IsOK()) - return status; - - if (is_transpose_required) { - // Perform the transpose to get the axes back to the original ordering - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), - Stream(ctx), - GetHipblasHandle(ctx), - permutation, *intermediate_output, *Y)); - } - - return Status::OK(); -} - -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ - template Status Softmax::ComputeInternal(OpKernelContext* ctx) const; - -SPECIALIZED_COMPUTE(float) -// MIOpen double data type not supported -// SPECIALIZED_COMPUTE(double) -SPECIALIZED_COMPUTE(MLFloat16) -SPECIALIZED_COMPUTE(BFloat16) - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax.h b/onnxruntime/core/providers/rocm/math/softmax.h deleted file mode 100644 index 57c1fc5068073..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
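Aside on the opset-13 handling in the Softmax kernel removed above: when axis is not the innermost dimension, the kernel transposes that axis to the innermost position, runs softmax there, and transposes back; because the permutation is a single swap, it is its own inverse. A minimal sketch of building that permutation and the transposed dims:

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::vector<size_t> SoftmaxAxisToInnermostPerm(size_t rank, size_t axis) {
  std::vector<size_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);   // identity permutation
  std::swap(perm[axis], perm[rank - 1]);    // move `axis` to the innermost slot
  return perm;
}

std::vector<int64_t> PermuteDims(const std::vector<int64_t>& dims,
                                 const std::vector<size_t>& perm) {
  std::vector<int64_t> out;
  for (size_t p : perm) out.push_back(dims[p]);
  return out;
}

// Example: rank 4, axis 1 -> perm {0, 3, 2, 1}; dims {2, 3, 4, 5} -> {2, 5, 4, 3}.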
- -#pragma once - -#include -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace rocm { - -using tunable::RocmTuningContext; - -template -Status SoftMaxComputeHelper( - Stream* stream, - const T* input, - const TensorShape& shape, - TOut* Y, - int64_t axis, - RocmTuningContext* tuning_ctx = nullptr); - -template -Status dispatch_warpwise_softmax_forward(Stream* stream, OutputT* dst, const InputT* src, int softmax_elements, - int softmax_elements_stride, int batch_count, - RocmTuningContext* tuning_ctx = nullptr); - -template -Status dispatch_blockwise_softmax_forward(Stream* stream, OutputT* output, const InputT* input, int softmax_elements, - int input_stride, int output_stride, int batch_count, - RocmTuningContext* tuning_ctx = nullptr); - -template -class Softmax final : public RocmKernel { - public: - Softmax(const OpKernelInfo& info) : RocmKernel{info} { - const auto& node = info.node(); - opset_ = node.SinceVersion(); - - int64_t axis; - Status status = info.GetAttr("axis", &axis); - - if (status.IsOK()) { - axis_ = gsl::narrow_cast(axis); - } else { - if (opset_ < 13) { - axis_ = 1; // opset-12 and below, the default axis value is 1 - } else { - axis_ = -1; // opset-13, the default axis value is -1 - } - } - - log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax"; - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - int64_t axis_; - bool log_softmax_; - int opset_; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh deleted file mode 100644 index f87b436d04a17..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
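Aside on SoftMaxComputeHelper removed above: the input is flattened into an (N, D) problem, with N the product of the dimensions before axis and D the product from axis onward; small rows (at most 1024 elements and 4096 bytes) take the warp-wise kernel, everything else the block-wise one. A sketch of that planning step, with the thresholds copied from the deleted code:

#include <cstddef>
#include <cstdint>
#include <vector>

struct SoftmaxPlan {
  int64_t batch_count;  // N
  int64_t elements;     // D
  bool use_warpwise;
};

SoftmaxPlan PlanSoftmax(const std::vector<int64_t>& shape, size_t axis, size_t elem_size) {
  int64_t n = 1, d = 1;
  for (size_t i = 0; i < shape.size(); ++i) (i < axis ? n : d) *= shape[i];
  const bool warpwise = d <= 1024 && d * static_cast<int64_t>(elem_size) <= 4096;
  return {n, d, warpwise};
}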
- -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif // USE_COMPOSABLE_KERNEL - -#include "core/providers/rocm/math/softmax_common.h" - -namespace onnxruntime { -namespace rocm { - -#ifdef USE_COMPOSABLE_KERNEL - -using Nop = ck::tensor_operation::element_wise::PassThrough; -constexpr int Rank = 4; -constexpr int NumReduceDim = 1; - -template -auto GetCKSoftmaxTypeStringAndOps() { - using InDataType = typename CKDataTypeAdaptor::type; - using OutDataType = typename CKDataTypeAdaptor::type; - using AccDataType = typename CKDataTypeAdaptor::type; - using DeviceSoftmax = ck::tensor_operation::device:: - DeviceSoftmax; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString(impl->GetTypeString()); - auto invoker = impl->MakeInvokerPointer(); - - auto ck_softmax_op = [impl = std::move(impl), invoker = std::move(invoker)](const SoftmaxParams* params) -> Status { - double alpha{1.0f}; - double beta{0.0f}; - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->is_log_softmax, - impl->GetTypeString(), " does not support log softmax"); - - std::vector in_lengths{1, 1, params->batch_count, params->softmax_elements}; - std::vector in_strides{params->batch_count * params->input_stride, params->batch_count * params->input_stride, params->input_stride, 1}; - std::vector reduce_dims{3}; - - auto nop = Nop{}; - auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, alpha, beta, - params->input, params->output, nop, nop); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_softmax_op))); - } - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_common.cc b/onnxruntime/core/providers/rocm/math/softmax_common.cc deleted file mode 100644 index 1cc36fe6d6cef..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_common.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
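Aside on the composable-kernel wrapper removed above: it feeds TunableOp a list of (type string, callable) candidates, and each callable may report the arguments as unsupported (for example log-softmax) so the tuner skips it. The rough sketch below shows only that candidate-list shape with illustrative names; it is not the CK or TunableOp API.

#include <functional>
#include <string>
#include <utility>
#include <vector>

struct Params { int batch_count; int elements; bool is_log_softmax; };
using Candidate = std::function<bool(const Params&)>;  // false = unsupported

std::vector<std::pair<std::string, Candidate>> GetCandidates() {
  std::vector<std::pair<std::string, Candidate>> ops;
  ops.emplace_back("static_blockwise", [](const Params&) { return true; });
  ops.emplace_back("ck_like_backend", [](const Params& p) {
    return !p.is_log_softmax;  // mirrors the "does not support log softmax" rejection
  });
  return ops;
}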
- -#include "core/providers/rocm/math/softmax_common.h" - -namespace onnxruntime { -namespace rocm { - -Status SoftmaxForward(miopenHandle_t miopen_handle, const void* alpha, const miopenTensorDescriptor_t input_tensor, - const void* input_data, const void* beta, const miopenTensorDescriptor_t output_tensor, - void* output_data) { - MIOPEN_RETURN_IF_ERROR(miopenSoftmaxForward_V2(miopen_handle, alpha, input_tensor, input_data, beta, output_tensor, - output_data, MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_INSTANCE)); - return Status::OK(); -} - -Status SoftmaxBackward(miopenHandle_t miopen_handle, bool is_log_softmax, const void* alpha, - const miopenTensorDescriptor_t input_tensor, const void* output_data, - const void* output_grad_data, const void* beta, const miopenTensorDescriptor_t output_tensor, - void* input_grad_data) { - MIOPEN_RETURN_IF_ERROR(miopenSoftmaxBackward_V2( - miopen_handle, alpha, input_tensor, output_data, input_tensor, output_grad_data, beta, output_tensor, - input_grad_data, is_log_softmax ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_INSTANCE)); - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_common.h b/onnxruntime/core/providers/rocm/math/softmax_common.h deleted file mode 100644 index 4a422b7bf9d7e..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_common.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/status.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace rocm { - -template -struct SoftmaxParams : tunable::OpParams { - SoftmaxParams(tunable::RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, OutputT* output, const InputT* input, - int softmax_elements, int input_stride, int output_stride, int batch_count, bool is_log_softmax) - : OpParams(tuning_ctx, stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {} - - std::string Signature() const override { - std::string sig = std::to_string(batch_count) + "_" + std::to_string(softmax_elements); - return sig; - } - - OutputT* output; - const InputT* input; - int softmax_elements; - int input_stride; - int output_stride; - int batch_count; - bool is_log_softmax; -}; - -Status SoftmaxForward(miopenHandle_t miopen_handle, const void* alpha, const miopenTensorDescriptor_t input_tensor, - const void* input_data, const void* beta, const miopenTensorDescriptor_t output_tensor, - void* output_data); - -Status SoftmaxBackward(miopenHandle_t miopen_handle, bool is_log_softmax, const void* alpha, - const miopenTensorDescriptor_t input_tensor, const void* output_data, - const void* output_grad_data, const void* beta, const miopenTensorDescriptor_t output_tensor, - void* input_grad_data); - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu deleted file mode 100644 index 316ba43b1f205..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Modifications Copyright (c) Microsoft. */ - -// The code below is mostly copied from Pytorch PersistentSoftmax.cuh -#include -#include "hip/hip_runtime.h" - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/math/softmax.h" -#include "core/providers/rocm/math/softmax_common.h" -#include "core/providers/rocm/math/softmax_tunable_op.cuh" - -#include - -namespace onnxruntime { -namespace rocm { - -template -Status dispatch_warpwise_softmax_forward(Stream* stream, OutputT* dst, const InputT* src, int softmax_elements, - int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx) { - SoftmaxParams params(tuning_ctx, stream, dst, src, softmax_elements, softmax_elements_stride, - softmax_elements_stride, batch_count, IsLogSoftmax); - if (tuning_ctx != nullptr && tuning_ctx->IsTunableOpEnabled()) { - static SoftmaxTunableOp op; - return op(¶ms); - } - return SoftmaxWarpwiseStaticSelection(¶ms); -} - -#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \ - template Status dispatch_warpwise_softmax_forward( \ - Stream* stream, OutputT * dst, const InputT* src, int softmax_elements, \ - int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx); \ - template Status dispatch_warpwise_softmax_forward( \ - Stream* stream, OutputT * dst, const InputT* src, int softmax_elements, \ - int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx); - -SPECIALIZED_SOFTMAX_IMPL(float, float, float) -SPECIALIZED_SOFTMAX_IMPL(half, half, float) -SPECIALIZED_SOFTMAX_IMPL(half, float, float) -SPECIALIZED_SOFTMAX_IMPL(double, double, double) -SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float) - -template -Status dispatch_blockwise_softmax_forward(Stream* stream, OutputT* output, - const InputT* input, int softmax_elements, - int input_stride, int output_stride, - int batch_count, RocmTuningContext* tuning_ctx) { - SoftmaxParams params(tuning_ctx, stream, output, input, softmax_elements, input_stride, - output_stride, batch_count, IsLogSoftmax); - if (tuning_ctx != nullptr && tuning_ctx->IsTunableOpEnabled()) { - static SoftmaxTunableOp op; - return op(¶ms); - } - return SoftmaxBlockwiseStaticSelection(¶ms); -} - -#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \ - template Status dispatch_blockwise_softmax_forward( \ - Stream* stream, OutputT * output, const InputT* input, int softmax_elements, \ - int input_stride, int output_stride, int batch_count, RocmTuningContext* tuning_ctx); \ - template Status dispatch_blockwise_softmax_forward( \ - Stream* stream, OutputT * output, const InputT* input, int softmax_elements, \ - int input_stride, int output_stride, int batch_count, RocmTuningContext* tuning_ctx); - -SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float) -SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float) -SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, float, float) -SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(double, double, double) 
-SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float) -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_triton.cuh b/onnxruntime/core/providers/rocm/math/softmax_triton.cuh deleted file mode 100644 index cc0e0d70056cc..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_triton.cuh +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/providers/rocm/math/softmax_common.h" -#include "core/providers/rocm/triton_kernel.h" - -namespace onnxruntime { -namespace rocm { - -#ifdef USE_TRITON_KERNEL - -namespace { - -template -std::string GetSoftmaxTritonGroupName() { - std::string ret = "softmax_"; - ret += GetDataTypeName(); - return ret; -} - -} // end of namespace - -template -auto GetSoftmaxTritonOps() { - std::vector>>> ret; - auto group_name = GetSoftmaxTritonGroupName(); - auto* kernel_list = GetOrtTritonKernelByGroup(group_name); - if (kernel_list == nullptr) { - return ret; - } - - for (auto i : *kernel_list) { - // check params match - auto* metadata = GetOrtTritonKernelMetadata(i); - auto block_size = -1; - const std::string block_name = "BLOCK_SIZE"; - if (metadata->constants.count(block_name) != 0) { - block_size = metadata->constants.at(block_name); - } - auto impl = [i, block_size](const SoftmaxParams* params) -> Status { - if (params->is_log_softmax) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, "log_softmax is not supported."); - } - if (block_size < params->softmax_elements) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, "BLOCK_SIZE (", block_size, ") is not supported."); - } - // construct args for launch kernel - struct { - void* out; - const void* in; - int in_stride; - int out_stride; - int n_cols; - } args = {(void*)params->output, (const void*)params->input, params->input_stride, params->output_stride, params->softmax_elements}; - - // grid dim is (batch_count, 1, 1) - return LaunchTritonKernel(params->StreamHandle(), i, params->batch_count, 1, 1, &args, sizeof(args)); - }; - ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); - } - return ret; -} - -#endif // USE_TRITON_KERNEL - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_triton.py b/onnxruntime/core/providers/rocm/math/softmax_triton.py deleted file mode 100644 index f4c9b6459082b..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_triton.py +++ /dev/null @@ -1,65 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -import triton -import triton.language as tl - - -@triton.jit -def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): - # The rows of the softmax are independent, so we parallelize across those - row_idx = tl.program_id(0) - # The stride represents how much we need to increase the pointer to advance 1 row - row_start_ptr = input_ptr + row_idx * input_row_stride - # The block size is the next power of two greater than n_cols, so we can fit each - # row in a single block - col_offsets = tl.arange(0, BLOCK_SIZE) - input_ptrs = row_start_ptr + col_offsets - # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols - row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")) - row_f32 = row.to(tl.float32) - # Subtract maximum for numerical stability - row_minus_max = row_f32 - tl.max(row_f32, axis=0) - # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - # Write back output to DRAM - output_row_start_ptr = output_ptr + row_idx * output_row_stride - output_ptrs = output_row_start_ptr + col_offsets - tl.store(output_ptrs, softmax_output.to(row.dtype), mask=col_offsets < n_cols) - - -# function_table = {'name': name, 'func': func, 'sig'=sig, kwargs={}}, - -dtypes = ["fp32", "fp16"] -blocks = [1024, 2048, 4096, 8192, 16384] -name_pattern = "softmax_{}_{}" -sig_pattern = "*{},*{},i32,i32,i32" -group_pattern = "softmax_{}" - - -def get_function_table(): - func_table = [] - - def get_num_warps(block_size): - num_warps = 4 - if block_size >= 2048: - num_warps = 8 - if block_size >= 4096: - num_warps = 16 - return num_warps - - for dtype in dtypes: - for b in blocks: - name = name_pattern.format(dtype, b) - group = group_pattern.format(dtype) - sig = sig_pattern.format(dtype, dtype) - num_warps = get_num_warps(b) - kwargs = {"num_warps": num_warps, "constants": {"BLOCK_SIZE": b}} - func_desc = {"name": name, "group": group, "func": softmax_kernel, "sig": sig, "kwargs": kwargs} - func_table.append(func_desc) - - return func_table diff --git a/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh b/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh deleted file mode 100644 index 06ee9f38f62ef..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
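Aside on the Triton, warp-wise and block-wise softmax kernels removed above: they all compute the numerically stable form, subtracting the row maximum before exponentiating so exp() cannot overflow, and for log-softmax returning x - max - log(sum). A CPU reference for one row:

#include <algorithm>
#include <cmath>

void SoftmaxRow(const float* x, float* y, int n, bool log_softmax) {
  const float mx = *std::max_element(x, x + n);   // row max for stability
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) sum += std::exp(x[i] - mx);
  for (int i = 0; i < n; ++i)
    y[i] = log_softmax ? (x[i] - mx) - std::log(sum)
                       : std::exp(x[i] - mx) / sum;
}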
- -#pragma once - -#include - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/math/softmax_ck.cuh" -#include "core/providers/rocm/math/softmax_common.h" -#include "core/providers/rocm/math/softmax_warpwise_impl.cuh" -#include "core/providers/rocm/math/softmax_blockwise_impl.cuh" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "core/providers/rocm/math/softmax_triton.cuh" - -namespace onnxruntime { -namespace rocm { - -template -Status SoftmaxBlockwiseOp(const SoftmaxParams* params) { - dim3 grid(params->batch_count); - dim3 block = SoftMax_getBlockSize(VecSize, params->softmax_elements); - if (params->is_log_softmax) { - softmax_block_forward - <<StreamHandle()>>>( - params->output, const_cast(params->input), - params->softmax_elements, params->input_stride, - params->output_stride); - } else { - softmax_block_forward - <<StreamHandle()>>>( - params->output, const_cast(params->input), - params->softmax_elements, params->input_stride, - params->output_stride); - } - return HIP_CALL(hipGetLastError()); -} - -template -Status SoftmaxWarpwiseStaticSelection(const SoftmaxParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->input_stride <= 1024 && params->input_stride * sizeof(InputT) <= 4096)); - if (params->softmax_elements == 0) { - return Status::OK(); - } else { - int log2_elements = log2_ceil(params->softmax_elements); - const int next_power_of_two = 1 << log2_elements; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = 1; - // use 256 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 256; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (params->batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { -#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \ - case L2E: \ - softmax_warp_forward \ - <<StreamHandle()>>>( \ - params->output, params->input, params->batch_count, \ - params->input_stride, params->softmax_elements, \ - params->is_log_softmax); \ - break; - LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 - LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 - LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 - LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 - LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 - LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 - LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 - LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 - LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 - LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 - LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024 - default: - break; - } - } - return HIP_CALL(hipGetLastError()); -} - -template -Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams* params) { - dim3 grid(params->batch_count); - constexpr int ILP = sizeof(float4) / sizeof(InputT); - dim3 block = SoftMax_getBlockSize(ILP, params->softmax_elements); - if (params->is_log_softmax) { - softmax_block_forward - <<StreamHandle()>>>( - params->output, const_cast(params->input), - params->softmax_elements, params->input_stride, - params->output_stride); - } else { - softmax_block_forward - <<StreamHandle()>>>( - params->output, const_cast(params->input), - 
params->softmax_elements, params->input_stride, - params->output_stride); - } - return HIP_CALL(hipGetLastError()); -} - -template -Status SoftmaxStaticSelection(const SoftmaxParams* params) { - auto status = SoftmaxWarpwiseStaticSelection(params); - if (!status.IsOK()) { - status = SoftmaxBlockwiseStaticSelection(params); - } - return status; -} - -template -class SoftmaxTunableOp : public tunable::TunableOp> { - public: - SoftmaxTunableOp() { - this->RegisterOp(SoftmaxStaticSelection); - this->RegisterOp(SoftmaxWarpwiseStaticSelection); - this->RegisterOp(SoftmaxBlockwiseStaticSelection); - this->RegisterOp(SoftmaxBlockwiseOp); - this->RegisterOp(SoftmaxBlockwiseOp); - this->RegisterOp(SoftmaxBlockwiseOp); - this->RegisterOp(SoftmaxBlockwiseOp); - this->RegisterOp(SoftmaxBlockwiseOp); - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKSoftmaxTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetSoftmaxTritonOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - // this->RegisterOp(SoftmaxTritonOp); -#endif - } -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh b/onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh deleted file mode 100644 index f30bb970e0177..0000000000000 --- a/onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh +++ /dev/null @@ -1,167 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// The code below is mostly copied from Pytorch PersistentSoftmax.cuh - -#pragma once -#include "hip/hip_runtime.h" -#include "core/providers/rocm/cu_inc/common.cuh" - -namespace onnxruntime { -namespace rocm { - -inline int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension. -// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024. -// The template arguments have the following meaning: -// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples. -// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small. 
-// A "WARP" contains "GPU_WARP_SIZE" threads, these treads are guaranteed to belong to the same warp. -// This is important because it means only __shfl_ instructions are required for reductions. -// Note that this means WARP_SIZE must be a power of two and <= architecture warp size. -// ROCM warp size is 64 for all existing GPU architecures, but there is no guarantee this will not change for future arch. -// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed. -// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t. -// This allows SoftMax to be fused with a cast immediately following the SoftMax. -// For instance: -// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor. -// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor. -// input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor. - -template -__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count, bool is_log_softmax) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + local_idx; - dst += first_batch * stride + local_idx; - - // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, - // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep - // the nested loops. - // This should have no impact on performance because the loops are unrolled anyway. - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - elements[i][it] = src[i * element_count + it * WARP_SIZE]; - } else { - elements[i][it] = -std::numeric_limits::infinity(); - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH]{0.0f}; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (is_log_softmax) { - sum[i] += expf((float)(elements[i][it] - max_value[i])); - } else { - elements[i][it] = expf((float)(elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - if (is_log_softmax) sum[i] = max_value[i] + logf((float)(sum[i])); -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < element_count) { - if (is_log_softmax) { - dst[i * element_count + it * WARP_SIZE] = elements[i][it] - sum[i]; - } else { - dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i]; - } - } else { - break; - } - } - } -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/miopen_common.cc b/onnxruntime/core/providers/rocm/miopen_common.cc deleted file mode 100644 index 6b08d392069a1..0000000000000 --- a/onnxruntime/core/providers/rocm/miopen_common.cc +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "miopen_common.h" -#include -#include "core/providers/cpu/tensor/utils.h" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -namespace onnxruntime { -namespace rocm { - -namespace { -std::string layoutTypeToString(miopenTensorLayout_t layout) { - if (layout == MIOPEN_NCHW_LAYOUT) { - return "NCHW"; - } else if (layout == MIOPEN_NHWC_LAYOUT) { - return "NHWC"; - } else { - ORT_THROW("Currently, ORT only supports two MIOpen layout: MIOPEN_NCHW_LAYOUT and MIOPEN_NHWC_LAYOUT."); - } -} - -// This functions was modified from https://github.com/ROCmSoftwarePlatform/MIOpen/src/include/miopen/tensor_layout.hpp -template -void tensorLayoutToStrides(const InlinedVector& len, - miopenTensorLayout_t len_tensor_layout, - miopenTensorLayout_t tensor_layout, - InlinedVector& strides) { - std::string len_layout = layoutTypeToString(len_tensor_layout); - std::string layout = layoutTypeToString(tensor_layout); - // Bind the layout and the dimension lengths together into a map. - std::map dim_to_len; - std::transform(len.begin(), - len.end(), - len_layout.begin(), - std::inserter(dim_to_len, dim_to_len.end()), - [](T l, char dim) { return std::make_pair(dim, l); }); - - // Now construct the strides according to layout by multiply the - // dimension lengths together. 
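// (Editorial note, not part of the removed file.) Assuming a dense layout,
// for lengths given in NCHW order as {N, C, H, W} this helper produces:
//   NCHW strides: {C*H*W, H*W, W, 1}
//   NHWC strides: {H*W*C, 1, W*C, C}
// i.e. each stride is the product of the lengths of all dimensions that
// follow that dimension in the target layout string.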
- std::transform(len_layout.begin(), - len_layout.end(), - strides.begin(), - [&layout, &dim_to_len](char cur_layout_char) { - auto pos = layout.find(cur_layout_char); - if (pos == std::string::npos) { - ORT_THROW(std::string("mismatched layout string - ").append(layout)); - } - return std::accumulate(layout.begin() + pos + 1, - layout.end(), - 1, - [&dim_to_len](T accumulator, char l) { - return accumulator * dim_to_len[l]; - }); - }); -} -} // namespace - -MiopenTensor::MiopenTensor() - : tensor_(nullptr) { -} - -MiopenTensor::~MiopenTensor() { - if (tensor_ != nullptr) { - miopenDestroyTensorDescriptor(tensor_); - tensor_ = nullptr; - } -} - -Status MiopenTensor::CreateTensorIfNeeded() { - if (!tensor_) - MIOPEN_RETURN_IF_ERROR(miopenCreateTensorDescriptor(&tensor_)); - return Status::OK(); -} - -Status MiopenTensor::Set(gsl::span input_dims, miopenDataType_t dataType, bool is_nhwc) { - if (is_nhwc) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "NHWC Tensor usage is not supported in AMD builds for now"); - } - - ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); - - int rank = gsl::narrow_cast(input_dims.size()); - TensorPitches pitches(input_dims); - InlinedVector dims(rank); - InlinedVector strides(rank); - for (int i = 0; i < rank; i++) { - dims[i] = gsl::narrow_cast(input_dims[i]); - strides[i] = gsl::narrow_cast(pitches[i]); - } - MIOPEN_RETURN_IF_ERROR(miopenSetTensorDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); - return Status::OK(); -} - -Status MiopenTensor::Set(miopenDataType_t dataType, miopenTensorLayout_t tensor_layout, int n, int c, int h, int w) { - ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); - - // miopenSetNdTensorDescriptorWithLayout doesn't support NHWC layout now. - // We use miopenSetTensorDescriptor with dims = [N, C, H, W], strides = [N*W*C, 1, W*C, C] for NHWC layout. - const int num_lens = 4; - InlinedVector dims = {n, c, h, w}; - InlinedVector strides(num_lens); - - miopenTensorLayout_t len_layout = MIOPEN_NCHW_LAYOUT; - tensorLayoutToStrides(dims, len_layout, tensor_layout, strides); - - MIOPEN_RETURN_IF_ERROR(miopenSetTensorDescriptor(tensor_, dataType, static_cast(num_lens), dims.data(), strides.data())); - return Status::OK(); -} - -Status MiopenTensor::Set(const MiopenTensor& x_desc, miopenBatchNormMode_t mode) { - ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); - MIOPEN_RETURN_IF_ERROR(miopenDeriveBNTensorDescriptor(tensor_, x_desc, mode)); - return Status::OK(); -} - -MiopenTensorDescriptor::MiopenTensorDescriptor() : desc_(nullptr) { - miopenCreateTensorDescriptor(&desc_); -} - -MiopenTensorDescriptor::~MiopenTensorDescriptor() { - if (desc_ != nullptr) { - miopenCreateTensorDescriptor(&desc_); - desc_ = nullptr; - } -} - -Status MiopenTensorDescriptor::Set(gsl::span filter_dims, miopenDataType_t data_type) { - if (!desc_) - MIOPEN_RETURN_IF_ERROR(miopenCreateTensorDescriptor(&desc_)); - - int rank = gsl::narrow_cast(filter_dims.size()); - InlinedVector w_dims(rank); - for (int i = 0; i < rank; i++) { - w_dims[i] = gsl::narrow_cast(filter_dims[i]); - } - - MIOPEN_RETURN_IF_ERROR(miopenSetTensorDescriptor(desc_, - data_type, - rank, - w_dims.data(), - nullptr)); - return Status::OK(); -} - -Status MiopenTensorDescriptor::Set(miopenDataType_t data_type, miopenTensorLayout_t tensor_layout, int k, int c, int h, int w) { - if (!desc_) - MIOPEN_RETURN_IF_ERROR(miopenCreateTensorDescriptor(&desc_)); - - // miopenSetNdTensorDescriptorWithLayout doesn't support NHWC layout now. 
- // We use miopenSetTensorDescriptor with dims = [N, C, H, W], strides = [N*W*C, 1, W*C, C] for NHWC layout. - const int num_lens = 4; - InlinedVector dims = {k, c, h, w}; - InlinedVector strides(num_lens); - - miopenTensorLayout_t len_layout = MIOPEN_NCHW_LAYOUT; - tensorLayoutToStrides(dims, len_layout, tensor_layout, strides); - - MIOPEN_RETURN_IF_ERROR(miopenSetTensorDescriptor(desc_, data_type, static_cast(num_lens), dims.data(), strides.data())); - return Status::OK(); -} - -template -miopenDataType_t MiopenTensor::GetDataType() { - ORT_THROW("miopen engine currently supports only single/half/int32/int8 precision data types."); -} - -#if ROCM_VERSION >= 50000 -template <> -miopenDataType_t MiopenTensor::GetDataType() { - return miopenDouble; -} -#endif - -template <> -miopenDataType_t MiopenTensor::GetDataType() { - return miopenFloat; -} - -template <> -miopenDataType_t MiopenTensor::GetDataType() { - return miopenHalf; -} - -template <> -miopenDataType_t MiopenTensor::GetDataType() { - ORT_THROW("miopen doesn't support BFloat16."); - return miopenFloat; -} - -template <> -miopenDataType_t MiopenTensor::GetDataType() { - return miopenInt32; -} - -template <> -miopenDataType_t MiopenTensor::GetDataType() { - return miopenInt8; -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/miopen_common.h b/onnxruntime/core/providers/rocm/miopen_common.h deleted file mode 100644 index eb4eb745b3692..0000000000000 --- a/onnxruntime/core/providers/rocm/miopen_common.h +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/rocm/rocm_common.h" - -#include - -const double MIOPEN_BN_MIN_EPSILON = 1e-5; - -namespace onnxruntime { -namespace rocm { - -#if MIOPEN_VERSION < 21800 -typedef enum { - miopenTensorNCHW = 0, - miopenTensorNHWC = 1, -} miopenTensorLayout_t; -#endif - -#define MIOPEN_CONVOLUTION_FWD_ALGO_COUNT 6 -#define MIOPEN_CONVOLUTION_BWD_FILTER_ALGO_COUNT 4 -#define MIOPEN_CONVOLUTION_BWD_DATA_ALGO_COUNT 6 -#define MIOPEN_NCHW_LAYOUT miopenTensorNCHW -#define MIOPEN_NHWC_LAYOUT miopenTensorNHWC - -class MiopenTensor final { - public: - MiopenTensor(); - ~MiopenTensor(); - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MiopenTensor); - - Status Set(gsl::span input_dims, miopenDataType_t dataType, bool is_nhwc = false); - Status Set(miopenDataType_t dataType, miopenTensorLayout_t tensor_layout, int n, int c, int h, int w); - Status Set(const MiopenTensor& x_desc, miopenBatchNormMode_t mode); - - operator miopenTensorDescriptor_t() const { return tensor_; } - - template - static miopenDataType_t GetDataType(); - - private: - Status CreateTensorIfNeeded(); - - miopenTensorDescriptor_t tensor_; -}; - -class MiopenTensorDescriptor final { - public: - MiopenTensorDescriptor(); - ~MiopenTensorDescriptor(); - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MiopenTensorDescriptor); - - Status Set(gsl::span filter_dims, miopenDataType_t data_type); - // Set 4D filter where k is output channels, c is input channels, h and w is rows and columns per filter. 
- Status Set(miopenDataType_t data_type, miopenTensorLayout_t tensor_layout, int k, int c, int h, int w); - - operator miopenTensorDescriptor_t() const { return desc_; } - - private: - miopenTensorDescriptor_t desc_; -}; - -template -struct Consts { - static const constexpr ElemType Zero{0}; - static const constexpr ElemType One{1}; -}; - -template <> -struct Consts { - static const constexpr float Zero{0}; - static const constexpr float One{1}; -}; - -template <> -struct Consts { - static const constexpr float Zero{0}; - static const constexpr float One{1}; -}; - -template -struct ReduceConsts { - static const constexpr ElemType Zero{0}; - static const constexpr ElemType One{1}; -}; - -#if ROCM_VERSION >= 40300 -// Up until ROCm 4.2 miopenReduceTensor() required alpha/beta to be the same data -// type as the input type. This differs from cudnnReduceTensor() and other -// MIOpen/cuDNN APIs where alpha/beta are float when input type is half (float16). -template <> -struct ReduceConsts { - static const constexpr float Zero{0}; - static const constexpr float One{1}; -}; - -template <> -struct ReduceConsts { - static const constexpr float Zero{0}; - static const constexpr float One{1}; -}; -#endif - -inline double ClampMiopenBatchNormEpsilon(double epsilon) { - if (epsilon < MIOPEN_BN_MIN_EPSILON) { - if (MIOPEN_BN_MIN_EPSILON - epsilon > FLT_EPSILON) - LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than MIOPEN_BN_MIN_EPSILON. Setting it to MIOPEN_BN_MIN_EPSILON"; - return MIOPEN_BN_MIN_EPSILON; - } - return epsilon; -} - -inline miopenStatus_t -BatchNormalizationForwardInferenceHelper(miopenHandle_t handle, - miopenBatchNormMode_t mode, - const void* alpha, - const void* beta, - const miopenTensorDescriptor_t xDesc, - const void* x, - const miopenTensorDescriptor_t yDesc, - void* y, - const miopenTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void* bnScale, - const void* bnBias, - const void* estimatedMean, - const void* estimatedVariance, - double epsilon) { - return miopenBatchNormalizationForwardInference(handle, - mode, - const_cast(alpha), - const_cast(beta), - xDesc, - x, - yDesc, - y, - bnScaleBiasMeanVarDesc, - const_cast(bnScale), - const_cast(bnBias), - const_cast(estimatedMean), - const_cast(estimatedVariance), - epsilon); -} - -inline miopenStatus_t -BatchNormalizationForwardTrainingHelper(miopenHandle_t handle, - miopenBatchNormMode_t mode, - const void* alpha, - const void* beta, - const miopenTensorDescriptor_t xDesc, - const void* x, - const miopenTensorDescriptor_t yDesc, - void* y, - const miopenTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void* bnScale, - const void* bnBias, - double exponentialAverageFactor, - void* resultRunningMean, - void* resultRunningVariance, - double epsilon, - void* resultSaveMean, - void* resultSaveInvVariance) { - return miopenBatchNormalizationForwardTraining(handle, - mode, - const_cast(alpha), - const_cast(beta), - xDesc, - x, - yDesc, - y, - bnScaleBiasMeanVarDesc, - const_cast(bnScale), - const_cast(bnBias), - exponentialAverageFactor, - resultRunningMean, - resultRunningVariance, - epsilon, - resultSaveMean, - resultSaveInvVariance); -} - -inline miopenStatus_t -LRNCrossChannelForwardHelper(miopenHandle_t handle, - miopenLRNDescriptor_t normDesc, - miopenLRNMode_t lrnMode, - const void* alpha, - const miopenTensorDescriptor_t xDesc, - const void* x, - const void* beta, - const miopenTensorDescriptor_t yDesc, - void* y) { - if (lrnMode != miopenLRNCrossChannel) { - LOGS_DEFAULT(ERROR) << __func__ << " must be called with 
lrnMode == miopenLRNCrossChannel"; - return miopenStatusBadParm; - } - return miopenLRNForward(handle, normDesc, alpha, xDesc, x, beta, yDesc, y, false, nullptr); -} - -inline miopenStatus_t -SetLRNDescriptorHelper(miopenLRNDescriptor_t normDesc, - unsigned lrnN, - double lrnAlpha, - double lrnBeta, - double lrnK) { - return miopenSetLRNDescriptor(normDesc, miopenLRNCrossChannel, lrnN, lrnAlpha, lrnBeta, lrnK); -} - -inline miopenStatus_t -PoolingForwardHelper(miopenHandle_t handle, - const miopenPoolingDescriptor_t poolDesc, - const void* alpha, - const miopenTensorDescriptor_t xDesc, - const void* x, - const void* beta, - const miopenTensorDescriptor_t yDesc, - void* y) { - return miopenPoolingForward(handle, poolDesc, alpha, xDesc, x, beta, yDesc, y, false, nullptr, 0); -} - -inline miopenStatus_t -SetPoolingNdDescriptorHelper(miopenPoolingDescriptor_t poolDesc, - const miopenPoolingMode_t mode, - miopenNanPropagation_t /* unavailable */, - int nbDims, - int* windowDimA, - int* padA, - int* stridesA) { - return miopenSetNdPoolingDescriptor(poolDesc, mode, nbDims, windowDimA, padA, stridesA); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc deleted file mode 100644 index f99885634b6c7..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ /dev/null @@ -1,423 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/nn/conv.h" -#include "core/common/span_utils.h" -#include "core/providers/rocm/nn/conv_impl.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tensor/slice.h" - -namespace onnxruntime { -namespace rocm { - -// Op Set 11 for Conv only update document to clearify default dilations and strides value. -// which are already convered by op set 11 cpu version, so simply add declaration. 
-#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 11, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) -// not yet supported in MIOpen -// REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -const miopenConvFwdAlgorithm_t Conv::kAllAlgos[] = { - miopenConvolutionFwdAlgoGEMM, - miopenConvolutionFwdAlgoDirect, - miopenConvolutionFwdAlgoFFT, - miopenConvolutionFwdAlgoWinograd, - miopenConvolutionFwdAlgoImplicitGEMM}; - -miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, miopenConvFwdAlgorithm_t algo, size_t* sz) { - return miopenConvolutionForwardGetWorkSpaceSize(handle, s.w_desc, s.x_tensor, s.conv_desc, s.y_tensor, sz); -} - -size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, - const miopenConvFwdAlgorithm_t* algo, int n_algo, int device_id = 0) { - // TODO: get maximum available size from memory arena - size_t free, total; - onnxruntime::rocm::hipMemGetInfoAlt(device_id, &free, &total); - // Assuming 10% of fragmentation - free = static_cast(static_cast(free) * 0.9); - size_t max_ws_size = 0; - for (int i = 0; i < n_algo; i++) { - miopenStatus_t err; - size_t sz; - err = GetWorkspaceSize(handle, s, algo[i], &sz); - if (miopenStatusSuccess != err || sz == 0 || sz < max_ws_size || sz > free) continue; - max_ws_size = sz; - } - return max_ws_size; -} - -Status SliceOutUnwantedOutputSection(hipStream_t stream, - const void* input_data, gsl::span input_dims, - void* output_data, - const gsl::span& output_dims, - const gsl::span& starts, - const gsl::span& ends, - const gsl::span& axes, - size_t element_size) { - SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); - - ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata)); - - // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape - ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims)); - - return SliceRocm::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); -} - -template -Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { - // set X - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - const auto x_dims = x_shape.AsShapeVector(); - s_.x_data = reinterpret_cast(X->Data()); - s_.element_size = X->DataType()->Size(); - // set W - const Tensor* W = context->Input(1); - const TensorShape& w_shape = W->Shape(); - auto w_dims = w_shape.AsShapeVector(); - s_.w_data = reinterpret_cast(W->Data()); - - // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
- constexpr bool channels_last = NHWC; - if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); - } - - // set B - if (context->InputCount() >= 3) { - const Tensor* B = context->Input(2); - s_.b_data = reinterpret_cast(B->Data()); - } else { - s_.b_data = nullptr; - } - // set Z - if (context->InputCount() >= 4) { - const Tensor* Z = context->Input(3); - ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), MiopenTensor::GetDataType())); - s_.z_data = reinterpret_cast(Z->Data()); - } else { - s_.z_data = nullptr; - } - bool input_dims_changed = (s_.last_x_dims != x_dims); - bool w_dims_changed = (s_.last_w_dims != w_dims); - if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) - s_.last_x_dims = gsl::make_span(x_dims); - - if (w_dims_changed) { - s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_fwd_results.clear(); - } - - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); - - TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, channels_last)); - - const size_t kernel_rank = kernel_shape.size(); - - ConvPadVector pads(conv_attrs_.pads); - if (pads.empty()) { - pads.resize(kernel_rank * 2, 0); - } - TensorShapeVector dilations(conv_attrs_.dilations); - if (dilations.empty()) { - dilations.resize(kernel_rank, 1); - } - TensorShapeVector strides(conv_attrs_.strides); - if (strides.empty()) { - strides.resize(kernel_rank, 1); - } - - TensorShapeVector y_dims; - y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C' - - const int64_t N = X->Shape()[0]; - const int64_t M = W->Shape()[0]; - if (channels_last) { - y_dims.push_back(N); - } else { - y_dims.insert(y_dims.begin(), {N, M}); - } - - bool post_slicing_required = false; - TensorShapeVector slice_starts; - slice_starts.reserve(kernel_rank); - - TensorShapeVector slice_ends; - slice_ends.reserve(kernel_rank); - - TensorShapeVector slice_axes; - slice_axes.reserve(kernel_rank); - - constexpr size_t spatial_dim_start = channels_last ? 1 : 2; - const size_t spatial_dim_end = spatial_dim_start + kernel_rank; - TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); - - TensorShapeVector y_dims_with_adjusted_pads(y_dims); - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, - strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, - post_slicing_required, slice_starts, slice_ends, slice_axes, - channels_last)); - if (channels_last) { - y_dims.push_back(M); - y_dims_with_adjusted_pads.push_back(M); - } - - ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); - s_.y_dims = gsl::make_span(y_dims); - s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; - s_.post_slicing_required = post_slicing_required; - s_.slice_starts = slice_starts; - s_.slice_ends = slice_ends; - s_.slice_axes = slice_axes; - - s_.Y = context->Output(0, TensorShape(s_.y_dims)); - if (post_slicing_required) { - // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. - s_.memory_for_miopen_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); - s_.y_data = reinterpret_cast(s_.memory_for_miopen_conv_results.get()); - } else { - // No post slicing needed. 
Fill the output tensor's buffer directly. - s_.y_data = reinterpret_cast(s_.Y->MutableData()); - } - - TensorShapeVector x_dims_miopen{x_dims.begin(), x_dims.end()}; - TensorShapeVector y_dims_miopen = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; - if (kernel_rank < 2) { - // TODO: Remove asym padding correction. - x_dims_miopen.push_back(1); - y_dims_miopen.push_back(1); - w_dims.push_back(1); - pads.insert(pads.begin() + kernel_rank, 0); - pads.insert(pads.end(), 0); - kernel_shape.push_back(1); - strides.push_back(1); - dilations.push_back(1); - } - - if (w_dims_changed) { - if (!channels_last) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(MiopenTensor::GetDataType(), - miopenTensorNHWC, - w_dims[0], - w_dims[3], - w_dims[1], - w_dims[2])); - } - } - - // We must delay returning early until here so that the weight dims have been cached properly - if (s_.Y->Shape().Size() == 0) { - return Status::OK(); - } - - if (channels_last) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(MiopenTensor::GetDataType(), - miopenTensorNHWC, - x_dims_miopen[0], - x_dims_miopen[3], - x_dims_miopen[1], - x_dims_miopen[2])); - - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(MiopenTensor::GetDataType(), - miopenTensorNHWC, - y_dims_miopen[0], - y_dims_miopen[3], - y_dims_miopen[1], - y_dims_miopen[2])); - } else { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_miopen, MiopenTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_miopen, MiopenTensor::GetDataType())); - } - - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, - gsl::narrow_cast(conv_attrs_.group), - miopenConvolution, MiopenTensor::GetDataType())); - - if (context->InputCount() >= 3) { - const Tensor* B = context->Input(2); - const auto& b_shape = B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); - TensorShapeVector b_dims(2 + kernel_shape.size(), 1); - b_dims[1] = b_shape[0]; - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, MiopenTensor::GetDataType())); - } else if (bias_expected) { - TensorShapeVector b_dims(2 + kernel_shape.size(), 1); - b_dims[1] = w_dims[0]; - auto malloc_size = b_dims[1] * sizeof(HipT); - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, MiopenTensor::GetDataType())); - if (s_.b_zero) { - HIP_CALL_THROW(hipFree(s_.b_zero)); - s_.b_zero = nullptr; - } - HIP_CALL_THROW(hipMalloc(&s_.b_zero, malloc_size)); - HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); - } - - if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) { - miopenConvAlgoPerf_t perf; - int algo_count = 1; - const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); - static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos, rocm_ep->GetDeviceId()) - : AlgoSearchWorkspaceSize; - IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm( - GetMiopenHandle(context), - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.y_tensor, - s_.y_data, - 1, // requestedAlgoCount - &algo_count, // returnedAlgoCount - &perf, - algo_search_workspace.get(), - max_ws_size, - false)); // Do not do exhaustive algo search. 
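Editor's note on the algorithm-search block above: the removed GetMaxWorkspaceSize keeps roughly 90% of the reported free device memory as a fragmentation margin and then picks the largest per-algorithm workspace request that still fits, falling back to a fixed AlgoSearchWorkspaceSize otherwise. A minimal standalone sketch of that selection logic, with hypothetical names and no HIP/MIOpen calls, purely for illustration:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Pick the largest candidate workspace that fits into the usable portion of
// free device memory, assuming ~10% of "free" is lost to fragmentation.
// Mirrors the selection loop in the removed GetMaxWorkspaceSize; the names
// here are illustrative, not ONNX Runtime APIs.
std::size_t PickMaxWorkspace(std::size_t free_bytes,
                             const std::vector<std::size_t>& candidate_sizes) {
  const std::size_t usable =
      static_cast<std::size_t>(static_cast<double>(free_bytes) * 0.9);
  std::size_t max_ws = 0;
  for (std::size_t sz : candidate_sizes) {
    if (sz == 0 || sz > usable) continue;  // skip empty or oversized requests
    max_ws = std::max(max_ws, sz);
  }
  return max_ws;
}

int main() {
  // 1 GiB free; three algorithms asking for 16 MiB, 256 MiB and 1000 MiB.
  const std::vector<std::size_t> sizes = {16777216, 268435456, 1048576000};
  std::printf("chosen workspace: %zu bytes\n", PickMaxWorkspace(1073741824, sizes));
  return 0;  // prints 268435456: the 1000 MiB request exceeds the 90% margin
}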
- s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory}); - } - const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen); - s_.fwd_algo = perf.fwd_algo; - s_.workspace_bytes = perf.memory; - } else { - // set Y - s_.Y = context->Output(0, TensorShape(s_.y_dims)); - if (s_.Y->Shape().Size() == 0) { - return Status::OK(); - } - if (s_.post_slicing_required) { - s_.memory_for_miopen_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); - s_.y_data = reinterpret_cast(s_.memory_for_miopen_conv_results.get()); - } else { - s_.y_data = reinterpret_cast(s_.Y->MutableData()); - } - } - return Status::OK(); -} - -template -Status Conv::ComputeInternal(OpKernelContext* context) const { - std::lock_guard lock(s_.mutex); - ORT_RETURN_IF_ERROR(UpdateState(context)); - if (s_.Y->Shape().Size() == 0) { - return Status::OK(); - } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); - auto miopen_handle = GetMiopenHandle(context); - MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(miopen_handle, - &alpha, - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.fwd_algo, - &beta, - s_.y_tensor, - s_.y_data, - workspace.get(), - s_.workspace_bytes)); - - constexpr bool channels_last = NHWC; - if (nullptr != s_.b_data && !channels_last) { - MIOPEN_RETURN_IF_ERROR(miopenConvolutionForwardBias(miopen_handle, &alpha, s_.b_tensor, s_.b_data, - &beta, s_.y_tensor, s_.y_data)); - } - // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions - // This may have lead to extra results that are unnecessary and hence we slice that off here - if (s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), - s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, - s_.slice_ends, s_.slice_axes, s_.element_size)); - } - if (nullptr != s_.b_data && channels_last) { - const Tensor* B = context->Input(2); - const auto& b_shape = B->Shape(); - - ConvBiasImpl(Stream(context), reinterpret_cast(s_.Y->MutableDataRaw()), - reinterpret_cast(B->Data()), - reinterpret_cast(s_.Y->MutableDataRaw()), b_shape[0], s_.Y->Shape().Size()); - } - return Status::OK(); -} - -MiopenConvolutionDescriptor::MiopenConvolutionDescriptor() : desc_(nullptr) { -} - -MiopenConvolutionDescriptor::~MiopenConvolutionDescriptor() { - if (desc_ != nullptr) { - miopenDestroyConvolutionDescriptor(desc_); - desc_ = nullptr; - } -} - -Status MiopenConvolutionDescriptor::Set( - size_t rank, - const gsl::span& pads, - const gsl::span& strides, - const gsl::span& dilations, - int groups, - miopenConvolutionMode_t mode, - miopenDataType_t data_type) { - if (!desc_) - MIOPEN_RETURN_IF_ERROR(miopenCreateConvolutionDescriptor(&desc_)); - - InlinedVector pad_dims(rank); - InlinedVector stride_dims(rank); - InlinedVector dilation_dims(rank); - for (size_t i = 0; i < rank; i++) { - pad_dims[i] = gsl::narrow_cast(pads[i]); - stride_dims[i] = gsl::narrow_cast(strides[i]); - dilation_dims[i] = gsl::narrow_cast(dilations[i]); - } - - MIOPEN_RETURN_IF_ERROR(miopenInitConvolutionNdDescriptor( - desc_, - gsl::narrow_cast(rank), - pad_dims.data(), - stride_dims.data(), - dilation_dims.data(), - mode)); - - MIOPEN_RETURN_IF_ERROR(miopenSetConvolutionGroupCount(desc_, groups)); - - return Status::OK(); -} - -#ifndef 
DISABLE_CONTRIB_OPS -// template instantiation for NhwcConv -template class Conv; -template class Conv; -#endif - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h deleted file mode 100644 index e6ebb5a380d3f..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/cpu/nn/conv_attributes.h" -#include - -namespace onnxruntime { - -using ConvPadVector = ConvAttributes::ConvPadVector; - -namespace rocm { - -class MiopenConvolutionDescriptor final { - public: - MiopenConvolutionDescriptor(); - ~MiopenConvolutionDescriptor(); - - Status Set(size_t rank, - const gsl::span& pads, - const gsl::span& strides, - const gsl::span& dilations, - int groups, - miopenConvolutionMode_t mode, - miopenDataType_t data_type); - - operator miopenConvolutionDescriptor_t() const { return desc_; } - - private: - miopenConvolutionDescriptor_t desc_; -}; - -struct vector_hash { - std::size_t operator()(const TensorShapeVector& values) const { - std::size_t seed = values.size(); - for (auto& val : values) - seed ^= std::hash()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - return seed; - } -}; - -template , - typename KeyEqual = std::equal_to, - typename ListAllocator = std::allocator> -class lru_unordered_map { - public: - lru_unordered_map(size_t max_size) : max_size_(max_size) {} - - void insert(const Key& key, const T& value) { - auto it = items_.find(key); - if (it != items_.end()) { - it->second.value = value; - move_to_front(it->second.lru_iterator); - return; - } - - while (size() + 1 > max_size_) { - items_.erase(lru_list_.back()); - lru_list_.pop_back(); - } - - lru_list_.emplace_front(key); - items_.emplace(key, value_type{value, lru_list_.begin()}); - } - - T& at(const Key& key) { - auto it = items_.find(key); - if (it == items_.end()) { - throw std::out_of_range("There is no such key in cache"); - } - move_to_front(it->second.lru_iterator); - return it->second.value; - } - - bool contains(const Key& key) const { - return items_.find(key) != items_.end(); - } - - size_t size() const { - return items_.size(); - } - - void clear() { - items_.clear(); - lru_list_.clear(); - } - - private: - using list_type = std::list; - using iterator_type = typename list_type::iterator; - struct value_type { - T value; - iterator_type lru_iterator; - }; - using MapAllocator = std::allocator>; - - void move_to_front(iterator_type it) { - lru_list_.splice(lru_list_.begin(), lru_list_, it); - } - - size_t max_size_; - std::unordered_map items_; - list_type lru_list_; -}; - -// cached miopen descriptors -constexpr size_t MAX_CACHED_ALGO_PERF_RESULTS = 10000; - -template -struct MiopenConvState { - // if x/w dims changed, update algo and miopenTensors - TensorShape last_x_dims; - TensorShape last_w_dims; - - // these would be recomputed if x/w dims change - TensorShape y_dims; - TensorShapeVector y_dims_with_adjusted_pads; - size_t workspace_bytes; - decltype(AlgoPerfType().bwd_data_algo) bwd_data_algo; - decltype(AlgoPerfType().fwd_algo) fwd_algo; - MiopenTensor x_tensor; - const void* x_data = nullptr; - size_t element_size = 0; - MiopenTensorDescriptor w_desc; - const void* w_data = nullptr; - MiopenTensor b_tensor; - const void* b_data = nullptr; 
- void* b_zero = nullptr; - MiopenTensor y_tensor; - Tensor* Y = nullptr; - void* y_data = nullptr; - MiopenTensor z_tensor; - const void* z_data = nullptr; - MiopenConvolutionDescriptor conv_desc; - - struct PerfFwdResultParams { - decltype(AlgoPerfType().fwd_algo) fwd_algo; - decltype(AlgoPerfType().memory) memory; - }; - - struct PerfBwdResultParams { - decltype(AlgoPerfType().bwd_data_algo) bwd_data_algo; - decltype(AlgoPerfType().memory) memory; - }; - - lru_unordered_map cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; - lru_unordered_map cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; - - // Some properties needed to support asymmetric padded Conv nodes - bool post_slicing_required; - TensorShapeVector slice_starts; - TensorShapeVector slice_ends; - TensorShapeVector slice_axes; - - // note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing - std::mutex mutex; - IAllocatorUniquePtr memory_for_miopen_conv_results; - - ~MiopenConvState() { - if (b_zero) { - HIP_CALL_THROW(hipFree(b_zero)); - b_zero = nullptr; - } - } -}; - -enum : size_t { - AlgoSearchWorkspaceSize = 32 * 1024 * 1024, -}; - -// ONNX Conv operator uses NCHW format for input, weights and output. -// NhwcConv contrib ops uses NHWC format: last dimension of input, weights and output are channels. -template -class Conv : public RocmKernel { - public: - using HipT = typename ToHipType::MappedType; - - Conv(const OpKernelInfo& info) : RocmKernel(info), conv_attrs_(info) { - auto pads_size = conv_attrs_.pads.size(); - ORT_ENFORCE(pads_size % 2 == 0); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { - return GetScratchBuffer(s_.workspace_bytes, stream); - } - - Status UpdateState(OpKernelContext* context, bool bias_expected = false) const; - ConvAttributes conv_attrs_; - mutable MiopenConvState s_; - constexpr static auto kDefaultConvAlgo = miopenConvolutionFwdAlgoGEMM; - static const miopenConvFwdAlgorithm_t kAllAlgos[]; -}; - -Status SliceOutUnwantedOutputSection(hipStream_t stream, - const void* input_data, - gsl::span input_dims, - void* output_data, - const gsl::span& output_dims, - const gsl::span& starts, - const gsl::span& ends, - const gsl::span& axes, - size_t element_size); -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/conv_impl.cu b/onnxruntime/core/providers/rocm/nn/conv_impl.cu deleted file mode 100644 index 98df9026a721b..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/conv_impl.cu +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
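Editor's note on the conv.h removed above: algorithm-search results are cached per input shape in an lru_unordered_map, whose eviction relies on the splice-to-front idiom (std::list holds recency order, std::unordered_map holds the payload plus a list iterator). A self-contained sketch of that technique with simplified types — not the ONNX Runtime class, and the names LruMap/Insert/Find/Touch are hypothetical:

#include <cstdio>
#include <list>
#include <string>
#include <unordered_map>

// Minimal LRU map: the most recently used key sits at the front of lru_;
// when full, the key at the back (least recently used) is evicted.
template <typename Key, typename Value>
class LruMap {
 public:
  explicit LruMap(std::size_t max_size) : max_size_(max_size) {}

  void Insert(const Key& key, const Value& value) {
    auto it = items_.find(key);
    if (it != items_.end()) {  // existing key: update payload, refresh recency
      it->second.value = value;
      Touch(it->second.pos);
      return;
    }
    while (items_.size() >= max_size_) {  // evict least recently used entries
      items_.erase(lru_.back());
      lru_.pop_back();
    }
    lru_.push_front(key);
    items_.emplace(key, Entry{value, lru_.begin()});
  }

  const Value* Find(const Key& key) {
    auto it = items_.find(key);
    if (it == items_.end()) return nullptr;
    Touch(it->second.pos);
    return &it->second.value;
  }

 private:
  struct Entry {
    Value value;
    typename std::list<Key>::iterator pos;
  };
  // splice moves the node to the front in O(1) without invalidating iterators
  void Touch(typename std::list<Key>::iterator pos) {
    lru_.splice(lru_.begin(), lru_, pos);
  }

  std::size_t max_size_;
  std::list<Key> lru_;
  std::unordered_map<Key, Entry> items_;
};

int main() {
  LruMap<std::string, int> cache(2);
  cache.Insert("algo_for_shape_A", 1);
  cache.Insert("algo_for_shape_B", 2);
  cache.Find("algo_for_shape_A");       // refresh A, so B becomes the LRU entry
  cache.Insert("algo_for_shape_C", 3);  // evicts B
  std::printf("B cached: %s\n", cache.Find("algo_for_shape_B") ? "yes" : "no");
  return 0;
}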
- -#include "core/providers/rocm/nn/conv_impl.h" - -#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh" -#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fast_divmod.h" - -namespace onnxruntime { -namespace rocm { - -template -void ConvBiasImpl( - hipStream_t stream, - const T1* lhs_data, - const T2* rhs_data, - T* output_data, - size_t bias_size, - size_t count) { - int output_rank_or_simple_broadcast = static_cast(SimpleBroadcast::RightPerChannelBatchN); - fast_divmod fdm_h(1); - fast_divmod fdm_c(bias_size); - - BinaryElementWiseImpl(stream, output_rank_or_simple_broadcast, nullptr, lhs_data, nullptr, rhs_data, - nullptr, fdm_h, fdm_c, output_data, OP_Add(), - count); -} - -template void ConvBiasImpl( - hipStream_t stream, - const float* lhs_data, - const float* rhs_data, - float* output_data, - size_t bias_size, - size_t count); - -template void ConvBiasImpl( - hipStream_t stream, - const half* lhs_data, - const half* rhs_data, - half* output_data, - size_t bias_size, - size_t count); - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/conv_impl.h b/onnxruntime/core/providers/rocm/nn/conv_impl.h deleted file mode 100644 index befe0f4634e0d..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/conv_impl.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace rocm { - -template -void ConvBiasImpl( - hipStream_t stream, - const T1* lhs_data, - const T2* rhs_data, - T* output_data, - size_t bias_size, - size_t count); - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc deleted file mode 100644 index a6848e90b406d..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "conv_transpose.h" - -namespace onnxruntime { -namespace rocm { - -// Op Set 11 for ConvTranspose only update document to clearify default dilations and strides value. -// which are already covered by op set 11 cpu version, so simply add declaration. 
-#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 11, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); - -REGISTER_KERNEL_TYPED(float) -// not yet supported in MIOpen -// REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { - return DoConvTranspose(context, false); -} - -template -Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { - typedef typename ToHipType::MappedType HipT; - - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - auto x_dims = x_shape.AsShapeVector(); - auto x_data = reinterpret_cast(X->Data()); - - auto x_dimensions = X->Shape().NumDimensions(); - if (x_dimensions < 3 || x_dimensions > 5) { - // TODO: the error message should tell which operator raises it. - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", - " X: ", X->Shape().ToString().c_str()); - } - const Tensor* W = context->Input(1); - const TensorShape& w_shape = W->Shape(); - auto w_dims = w_shape.AsShapeVector(); - auto w_data = reinterpret_cast(W->Data()); - - size_t num_inputs = OpKernel::Node().InputDefs().size(); - bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; - - HipT* y_data = nullptr; - if (x_dimensions == 3) { - x_dims.insert(x_dims.begin() + 2, 1); - w_dims.insert(w_dims.begin() + 2, 1); - } - - { - std::lock_guard lock(s_.mutex); - // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size - bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); - bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); - if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) - s_.last_x_dims = gsl::make_span(x_dims); - - if (w_dims_changed) { - s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_bwd_results.clear(); - } - - ConvTransposeAttributes::Prepare p; - ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding)); - - auto y_dims = p.Y->Shape().AsShapeVector(); - if (x_dimensions == 3) { - y_dims.insert(y_dims.begin() + 2, 1); - p.kernel_shape.insert(p.kernel_shape.begin(), 1); - p.pads.insert(p.pads.begin(), 0); - p.pads.insert(p.pads.begin() + 2, 0); - p.strides.insert(p.strides.begin(), 1); - p.dilations.insert(p.dilations.begin(), 1); - } - s_.y_dims = gsl::make_span(y_dims); - - if (w_dims_changed) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); - } - - // Special case when there is a dim value of 0 in the shape. 
- // Return only after we have cached the following for subsequent runs : - // 1) `w_dims` in the `w_desc` - // 2) `y_dims` in s_.y_dims - if (p.Y->Shape().Size() == 0) { - return Status::OK(); - } - - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, MiopenTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, MiopenTensor::GetDataType())); - - miopenConvolutionMode_t mode = miopenConvolution; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, - gsl::narrow_cast(conv_transpose_attrs_.group), - mode, MiopenTensor::GetDataType())); - - if (has_bias) { - const auto& b_shape = p.B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); - std::vector b_dims(2 + p.kernel_shape.size()); - b_dims[0] = 1; // N - b_dims[1] = b_shape[0]; // C - for (size_t i = 0; i < p.kernel_shape.size(); i++) - b_dims[2 + i] = 1; - - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, MiopenTensor::GetDataType())); - } - - y_data = reinterpret_cast(p.Y->MutableData()); - - if (!s_.cached_benchmark_bwd_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); - - miopenConvAlgoPerf_t perf; - int algo_count = 1; - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm( - GetMiopenHandle(context), - s_.x_tensor, - x_data, - s_.w_desc, - w_data, - s_.conv_desc, - s_.y_tensor, - y_data, - 1, - &algo_count, - &perf, - algo_search_workspace.get(), - AlgoSearchWorkspaceSize, - false)); - s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory}); - } - - const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims); - s_.bwd_data_algo = perf.bwd_data_algo; - s_.workspace_bytes = perf.memory; - } - - // The following block will be executed in case there has been no change in the shapes of the - // input and the filter compared to the previous run - if (!y_data) { - auto y_dims = s_.y_dims.AsShapeVector(); - if (x_dimensions == 3) { - y_dims.erase(y_dims.begin() + 2); - } - Tensor* Y = context->Output(0, TensorShape(y_dims)); - y_data = reinterpret_cast(Y->MutableData()); - - // Bail out early if one of the output dimensions is zero. - if (Y->Shape().Size() == 0) { - return Status::OK(); - } - } - - const auto alpha = Consts::One; - const auto beta = Consts::Zero; - - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); - - MIOPEN_RETURN_IF_ERROR( - miopenConvolutionBackwardData( - GetMiopenHandle(context), - &alpha, - s_.x_tensor, - x_data, - s_.w_desc, - w_data, - s_.conv_desc, - s_.bwd_data_algo, - &beta, - s_.y_tensor, - y_data, - workspace.get(), - s_.workspace_bytes)); - - if (has_bias) { - const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); - auto b_data = reinterpret_cast(B->Data()); - MIOPEN_RETURN_IF_ERROR((miopenConvolutionForwardBias(GetMiopenHandle(context), &alpha, s_.b_tensor, b_data, &beta, s_.y_tensor, y_data))); - } - } - - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.h b/onnxruntime/core/providers/rocm/nn/conv_transpose.h deleted file mode 100644 index 55a84cc59fe92..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/rocm/nn/conv.h" -#include "core/providers/cpu/nn/conv_transpose_attributes.h" - -namespace onnxruntime { -namespace rocm { - -template -class ConvTranspose : public RocmKernel { - public: - ConvTranspose(const OpKernelInfo& info) : RocmKernel(info), conv_transpose_attrs_(info) { - static_assert(!NHWC, "AMD builds don't support usage of NHWC ops"); - }; - Status ComputeInternal(OpKernelContext* context) const override; - Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; - - private: - ConvTransposeAttributes conv_transpose_attrs_; - - mutable MiopenConvState s_; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/pool.cc b/onnxruntime/core/providers/rocm/nn/pool.cc deleted file mode 100644 index 3a82ab598004b..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/pool.cc +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/rocm/nn/pool.h" -#include "core/providers/rocm/nn/max_pool_with_index.h" -#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" -#include "core/providers/rocm/reduction/reduction_ops.h" - -using namespace onnxruntime::common; -namespace onnxruntime { -namespace rocm { - -#define POOLING_KERNEL_(op_name, data_type, pool, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - data_type, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - pool); - -#define POOLING_KERNEL(op_name, data_type, pool_type, since_version) \ - POOLING_KERNEL_(op_name, data_type, Pool, pool_type, since_version) - -#define GLOBAL_POOLING_KERNEL(op_name, data_type, pool_type, since_version) \ - POOLING_KERNEL_(op_name, data_type, GlobalPool, pool_type, since_version) - -#define POOLING_KERNEL_VERSIONED(op_name, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - end_version, \ - data_type, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_WITH_INDICES(op_name, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - data_type, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - end_version, \ - data_type, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); - -POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 10, 10) 
-POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 10, 10) -POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 10, 10) -// AveragePool and MaxPool op set 11 only update spec document on default value for dilations and strides. -POOLING_KERNEL(AveragePool, float, AveragePool, 11) -POOLING_KERNEL(AveragePool, double, AveragePool, 11) -POOLING_KERNEL(AveragePool, MLFloat16, AveragePool, 11) -POOLING_KERNEL_VERSIONED(MaxPool, float, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED(MaxPool, double, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED(MaxPool, MLFloat16, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11) -POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, double, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12) - -GLOBAL_POOLING_KERNEL(GlobalAveragePool, float, AveragePool, 1) -GLOBAL_POOLING_KERNEL(GlobalAveragePool, double, AveragePool, 1) -GLOBAL_POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1) -GLOBAL_POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1) -GLOBAL_POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1) -GLOBAL_POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1) - -class MiopenPoolingDescriptor final { - public: - MiopenPoolingDescriptor() : desc_(nullptr) { - } - - ~MiopenPoolingDescriptor() { - if (desc_ != nullptr) { - miopenDestroyPoolingDescriptor(desc_); - desc_ = nullptr; - } - } - - MiopenPoolingDescriptor(const MiopenPoolingDescriptor&) = delete; - MiopenPoolingDescriptor& operator=(const MiopenPoolingDescriptor&) = delete; - - Status Set(miopenPoolingMode_t mode, - const gsl::span& kernel_shape, - const gsl::span& pads, - const gsl::span& strides) { - if (!desc_) - MIOPEN_RETURN_IF_ERROR(miopenCreatePoolingDescriptor(&desc_)); - - int rank = gsl::narrow_cast(kernel_shape.size()); - InlinedVector window(rank); - InlinedVector padding(rank); - InlinedVector stride(rank); - for (int i = 0; i < rank; i++) { - window[i] = gsl::narrow_cast(kernel_shape[i]); - } - for (int i = 0; i < rank; i++) { - padding[i] = gsl::narrow_cast(pads[i]); - } - for (int i = 0; i < rank; i++) { - stride[i] = gsl::narrow_cast(strides[i]); - } - MIOPEN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper( - desc_, - mode, - MIOPEN_PROPAGATE_NAN, - rank, - window.data(), - padding.data(), - stride.data())); - - return Status::OK(); - } - - operator miopenPoolingDescriptor_t() const { return desc_; } - - private: - miopenPoolingDescriptor_t desc_; -}; - -template -Status Pool::ComputeInternal(OpKernelContext* context) const { - typedef typename ToHipType::MappedType HipT; - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - const auto x_dims = x_shape.GetDims(); - - if (x_shape.NumDimensions() < 3) { - 
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input dimension cannot be less than 3."); - } - - auto kernel_shape = pool_attrs_.kernel_shape; - auto pads = pool_attrs_.pads; - auto strides = pool_attrs_.strides; - ORT_ENFORCE(!this->pool_attrs_.global_pooling); - - auto y_dims = pool_attrs_.SetOutputSize(x_shape, x_shape[1], &pads); - TensorShape y_shape(y_dims); - Tensor* Y = context->Output(0, y_shape); - // special case when there is a dim value of 0 in the shape. - if (y_shape.Size() == 0) - return Status::OK(); - - auto x_data = reinterpret_cast(X->Data()); - auto y_data = reinterpret_cast(Y->MutableData()); - - TensorShapeVector x_dims_miopen(x_dims.begin(), x_dims.end()); - TensorShapeVector y_dims_miopen(y_dims); - if (kernel_shape.size() < 2) { - // miopen only takes 4D or 5D input, so pad dimensions if needed - x_dims_miopen.push_back(1); - y_dims_miopen.push_back(1); - pads.insert(pads.begin() + kernel_shape.size(), 0); - pads.insert(pads.end(), 0); - kernel_shape.push_back(1); - strides.push_back(1); - } - - miopenPoolingMode_t mode = miopenPoolingMax; - if constexpr (PoolType::type == onnxruntime::PoolType::kAveragePool) { - mode = pool_attrs_.count_include_pad ? miopenPoolingAverageInclusive - : miopenPoolingAverage; - } - MiopenPoolingDescriptor pooling_desc; - ORT_RETURN_IF_ERROR(pooling_desc.Set(mode, kernel_shape, pads, strides)); - - if constexpr (std::is_same::value || std::is_same::value) { - // Cast to float back and forth using temp buffer - const auto alpha = Consts::One; - const auto beta = Consts::Zero; - MiopenTensor x_tensor; - MiopenTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_miopen, MiopenTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_miopen, MiopenTensor::GetDataType())); - - const auto input_count = x_shape.Size(); - const auto output_count = y_shape.Size(); - - IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, context->GetComputeStream()); - auto temp_Y = GetScratchBuffer(output_count, context->GetComputeStream()); - Impl_Cast(Stream(context), reinterpret_cast(x_data), temp_X.get(), input_count); - MIOPEN_RETURN_IF_ERROR(PoolingForwardHelper(GetMiopenHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get())); - Impl_Cast(Stream(context), temp_Y.get(), y_data, output_count); - } else { - const auto alpha = Consts::One; - const auto beta = Consts::Zero; - MiopenTensor x_tensor; - MiopenTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_miopen, MiopenTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_miopen, MiopenTensor::GetDataType())); - - MIOPEN_RETURN_IF_ERROR(PoolingForwardHelper(GetMiopenHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data)); - } - - return Status::OK(); -} - -template -Status Pool>::ComputeInternal(OpKernelContext* context) const { - typedef typename ToHipType::MappedType HipT; - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - - if (x_shape.NumDimensions() < 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input dimension cannot be less than 3."); - } - - auto kernel_shape = this->pool_attrs_.kernel_shape; - auto pads = this->pool_attrs_.pads; - auto strides = this->pool_attrs_.strides; - ORT_ENFORCE(!this->pool_attrs_.global_pooling); - - auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, x_shape[1], &pads); - Tensor* Y = context->Output(0, TensorShape(y_dims)); - - // special case when there is a dim value of 0 in the shape. 
- if (Y->Shape().Size() == 0) - return Status::OK(); - - auto x_data = reinterpret_cast(X->Data()); - auto y_data = reinterpret_cast(Y->MutableData()); - - Tensor* I = context->Output(1, TensorShape(y_dims)); - if (nullptr != I || !this->pool_attrs_.default_dilations) { - auto i_data = nullptr == I ? nullptr : I->MutableData(); - MaxPoolWithIndex( - this->Stream(context), - x_shape, - TensorShape(y_dims), - kernel_shape, - strides, - pads, - this->pool_attrs_.dilations, - this->pool_attrs_.storage_order, - x_data, - y_data, - i_data); - } else { - ORT_RETURN_IF_ERROR((Pool>::ComputeInternal(context))); - } - return Status::OK(); -} - -template -Status GlobalPool::ComputeInternal(OpKernelContext* context) const { - using HipT = typename ToHipType::MappedType; - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - - if (x_shape.NumDimensions() < 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input dimension cannot be less than 3."); - } - - ORT_ENFORCE(this->pool_attrs_.global_pooling); - - miopenReduceTensorOp_t reduce_op; - if constexpr (PoolType::type == onnxruntime::PoolType::kAveragePool) { - reduce_op = MIOPEN_REDUCE_TENSOR_AVG; - } else if (PoolType::type == onnxruntime::PoolType::kMaxPool) { - reduce_op = MIOPEN_REDUCE_TENSOR_MAX; - } else { - ORT_NOT_IMPLEMENTED(); - } - - miopenDataType_t miopen_type_X = MiopenTensor::GetDataType(); - - MiopenReduceDescriptor reduce_desc; - if constexpr (std::is_same::value) { - ORT_RETURN_IF_ERROR(reduce_desc.Set( - reduce_op, MiopenTensor::GetDataType(), MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES)); - } else { - ORT_RETURN_IF_ERROR(reduce_desc.Set(reduce_op, miopen_type_X, MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES)); - } - - auto x_dims = x_shape.AsShapeVector(); - TensorShapeVector y_dims; - y_dims.resize(x_dims.size(), 1); - y_dims[0] = x_dims[0]; - y_dims[1] = x_dims[1]; - - Tensor* Y = context->Output(0, y_dims); - - MiopenTensor input_tensor; - MiopenTensor output_tensor; - ORT_RETURN_IF_ERROR(input_tensor.Set(x_dims, miopen_type_X)); - ORT_RETURN_IF_ERROR(output_tensor.Set(y_dims, miopen_type_X)); - - auto miopen_handle = this->GetMiopenHandle(context); - size_t workspace_bytes{}; - MIOPEN_RETURN_IF_ERROR(miopenGetReductionWorkspaceSize( - miopen_handle, reduce_desc, input_tensor, output_tensor, &workspace_bytes)); - auto workspace_buffer = RocmKernel::GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - size_t indices_bytes{}; - MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize( - miopen_handle, reduce_desc, input_tensor, output_tensor, &indices_bytes)); - auto indices_buffer = RocmKernel::GetScratchBuffer(indices_bytes, context->GetComputeStream()); - - const auto one = ReduceConsts::One; - const auto zero = ReduceConsts::Zero; - - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - miopen_handle, reduce_desc, indices_buffer.get(), indices_bytes, workspace_buffer.get(), workspace_bytes, - &one, input_tensor, reinterpret_cast(X->DataRaw()), - &zero, output_tensor, reinterpret_cast(Y->MutableDataRaw()))); - - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/nn/pool.h b/onnxruntime/core/providers/rocm/nn/pool.h deleted file mode 100644 index 7554e544f3377..0000000000000 --- a/onnxruntime/core/providers/rocm/nn/pool.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/providers/cpu/nn/pool_base.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace rocm { - -template -class Pool : public RocmKernel, public PoolBase { - public: - Pool(const OpKernelInfo& info) : RocmKernel(info), PoolBase(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; -}; - -template -class Pool> final : public Pool> { - public: - Pool(const OpKernelInfo& info) : Pool>(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; -}; - -template -class GlobalPool final : public Pool { - public: - GlobalPool(const OpKernelInfo& info) : Pool(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc deleted file mode 100644 index d8b7e26d17b65..0000000000000 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ /dev/null @@ -1,909 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/cpu/tensor/utils.h" -#include "core/providers/rocm/reduction/reduction_ops.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/rocm/math/binary_elementwise_ops_impl.h" -#include "core/providers/rocm/math/binary_elementwise_ops.h" -#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" -#ifdef ENABLE_TRAINING -#include "contrib_ops/cpu/aten_ops/aten_op.h" -#endif - -using namespace onnxruntime::common; -namespace onnxruntime { -namespace rocm { - -#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - begin, end, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - version, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ - name); - -#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ - REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ - REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) - -#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \ - REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \ - REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \ - REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13) - -// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. -template -template -Status ReduceKernel::ReduceKernelShared( - const T* X, - const TensorShape& input_shape, - OutT* Y, - const TensorShape& output_shape, - miopenReduceTensorOp_t miopen_reduce_op, - miopenHandle_t miopen_handle, - onnxruntime::Stream* stream, - TensorShapeVector& output_dims) const { - typedef typename ToHipType::MappedType HipT; - typedef typename ToHipType::MappedType HipOutT; - miopenDataType_t miopen_type_X = MiopenTensor::GetDataType(); - const auto rank = input_shape.NumDimensions(); - - auto hip_stream = stream ? 
static_cast(stream->GetHandle()) : nullptr; - // Block of fast matrix reduction. - if (fast_reduction_) { - int m{}, n{}; - const auto applicable_matrix_reduction = get_applicable_matrix_reduction( - miopen_reduce_op, input_shape.GetDims(), axes_, m, n); - switch (applicable_matrix_reduction) { - case ApplicableMatrixReduction::Rows: { - return reduce_matrix_rows( - hip_stream, - reinterpret_cast(X), - reinterpret_cast(Y), - m, n, false); - } - case ApplicableMatrixReduction::Columns: - // don't call reduce_matrix_columns() since it will reset initial output data - default: - break; - } - } - - int64_t input_count = input_shape.Size(); - IAllocatorUniquePtr temp_X; - if (ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same::value) { - // ArgMax/ArgMin with FP16 are not supported by miopen, so convert input to fp32 then call miopen - temp_X = GetScratchBuffer(input_count, stream); - miopen_type_X = miopenFloat; - Impl_Cast(hip_stream, reinterpret_cast(X), temp_X.get(), input_shape.Size()); - } - - // MIOpen requires at least 3D input, so pad 1s if needed - auto input_dims_miopen = input_shape.AsShapeVector(); - auto output_dims_miopen = output_dims; - if (rank < 3) { - TensorShapeVector pads(3 - rank, 1); - input_dims_miopen.insert(input_dims_miopen.end(), pads.begin(), pads.end()); - output_dims_miopen.insert(output_dims_miopen.end(), pads.begin(), pads.end()); - } - - MiopenReduceDescriptor reduce_desc; - if constexpr (std::is_same::value) - ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, MiopenTensor::GetDataType(), ReduceTensorIndices)); - else - ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, ReduceTensorIndices)); - const auto one = ReduceConsts::One; - const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; - MiopenTensor output_tensor; - ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); - ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X)); - size_t workspace_bytes = 0; - MIOPEN_RETURN_IF_ERROR(miopenGetReductionWorkspaceSize(miopen_handle, reduce_desc, input_tensor, output_tensor, &workspace_bytes)); - auto workspace_rocm = GetScratchBuffer(workspace_bytes, stream); - - size_t indices_bytes = 0; - MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(miopen_handle, reduce_desc, input_tensor, output_tensor, &indices_bytes)); - auto indices_rocm = GetScratchBuffer(indices_bytes, stream); - - // need to allocate a separate buffer for ArgMin/ArgMax comparison output - auto output_count = output_shape.Size(); - - if (ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_NO_INDICES) { - IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); - HipT* input_data = nullptr; - if (calculate_sqt_) { - input_data_buffer = GetScratchBuffer(input_count, stream); - input_data = reinterpret_cast(input_data_buffer.get()); - fast_divmod tmp_div; - Impl_Mul(hip_stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, - reinterpret_cast(X), nullptr, - reinterpret_cast(X), nullptr, - tmp_div, tmp_div, - input_data, input_count); - } else if (log_sum_exp_) { - // Reduce max -- Max/Min will output indices data - MiopenReduceDescriptor reduce_max_desc; - ORT_RETURN_IF_ERROR(reduce_max_desc.Set(MIOPEN_REDUCE_TENSOR_MAX, miopen_type_X, MIOPEN_REDUCE_TENSOR_NO_INDICES)); - size_t indices_bytes_max = 0; - MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(miopen_handle, reduce_max_desc, input_tensor, output_tensor, &indices_bytes_max)); - auto indices_rocm_max = GetScratchBuffer(indices_bytes, stream); - 
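Editor's note on the ReduceLogSumExp path that continues below: it applies the standard max-shift trick — reduce-max, exp(x - max), sum, log, then add the max back — so that exp never overflows. A scalar CPU sketch of the same computation, for reference only (the function name LogSumExp is hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable log(sum_i exp(x_i)) computed as
// max + log(sum_i exp(x_i - max)), the same shift used by the
// MIOpen-based reduction sequence in the surrounding code.
double LogSumExp(const std::vector<double>& x) {
  const double m = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - m);  // exp(x_i - max) stays bounded
  return m + std::log(sum);
}

int main() {
  const std::vector<double> x = {1000.0, 1000.5, 999.0};  // exp(1000) alone would overflow
  std::printf("logsumexp = %.6f\n", LogSumExp(x));
  return 0;
}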
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - miopen_handle, reduce_max_desc, indices_rocm_max.get(), indices_bytes_max, workspace_rocm.get(), workspace_bytes, - &one, input_tensor, reinterpret_cast(X), - &zero, output_tensor, reinterpret_cast(Y))); - - // Exp(X-ReduceMax) - const TensorShape rhs_shape(output_dims); - auto exp_result_buffer = GetScratchBuffer(input_count, stream); - auto exp_result = exp_result_buffer.get(); - auto log_sum_result_buffer = GetScratchBuffer(output_count, stream); - auto log_sum_result = log_sum_result_buffer.get(); - BinaryElementwisePreparation prepare; - ORT_RETURN_IF_ERROR(prepare.BinaryElementwiseBroadcastPrepareHelper(input_shape, rhs_shape, input_shape)); - Impl_Sub(hip_stream, - prepare.output_rank_or_simple_broadcast, - &prepare.lhs_padded_strides, - reinterpret_cast(X), - &prepare.rhs_padded_strides, - reinterpret_cast(Y), - &prepare.fdm_output_strides, - prepare.fdm_H, prepare.fdm_C, - reinterpret_cast(exp_result), input_count); - - Impl_Exp(hip_stream, reinterpret_cast(exp_result), - reinterpret_cast(exp_result), - input_count); - - // ReduceSum - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - miopen_handle, reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, - &one, input_tensor, exp_result, - &zero, output_tensor, reinterpret_cast(log_sum_result))); - - // Log(Sum) - Impl_Log(hip_stream, reinterpret_cast(log_sum_result), - reinterpret_cast(log_sum_result), - output_count); - - // Log + ReduceMax - fast_divmod tmp_div; - Impl_Add(hip_stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, - reinterpret_cast(log_sum_result), nullptr, - reinterpret_cast(Y), nullptr, - tmp_div, tmp_div, - reinterpret_cast(Y), output_count); - - return Status::OK(); - } - if (calculate_sqt_) { - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - miopen_handle, reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, - &one, input_tensor, input_data, - &zero, output_tensor, reinterpret_cast(Y))); - } else { - // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case - if (input_count == output_count) { - if (reinterpret_cast(Y) != reinterpret_cast(X)) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y, X, input_count * sizeof(T), hipMemcpyDeviceToDevice, hip_stream)); - } - } else { - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - miopen_handle, reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, - &one, input_tensor, reinterpret_cast(X), - &zero, output_tensor, reinterpret_cast(Y))); - } - } - } else { // For ArgMax & ArgMin ops, use the indicies as the output with int64 type - if (temp_X) { - auto temp_output = GetScratchBuffer(output_count, stream); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - miopen_handle, reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, - &one, input_tensor, temp_X.get(), - &zero, output_tensor, temp_output.get())); - } else { - auto temp_output = GetScratchBuffer(output_count, stream); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - miopen_handle, reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, - &one, input_tensor, reinterpret_cast(X), - &zero, output_tensor, temp_output.get())); - } - - // MIOpen reduction index is uint32_t for now, cast it to int64_t according to ONNX spec - Impl_Cast(hip_stream, reinterpret_cast(indices_rocm.get()), reinterpret_cast(Y), output_count); - } - - if (calculate_log_) { - 
Impl_Log(hip_stream, reinterpret_cast(Y), - reinterpret_cast(Y), - output_count); - } - - return Status::OK(); -} - -// template Status ReduceKernel::ReduceKernelShared( -// const double* X, -// const TensorShape& input_shape, -// double* Y, -// const TensorShape& output_shape, -// miopenReduceTensorOp_t miopen_reduce_op, -// miopenHandle_t miopen_handle, -// onnxruntime::Stream* stream, -// TensorShapeVector& output_dims) const; - -template Status ReduceKernel::ReduceKernelShared( - const float* X, - const TensorShape& input_shape, - float* Y, - const TensorShape& output_shape, - miopenReduceTensorOp_t miopen_reduce_op, - miopenHandle_t miopen_handle, - onnxruntime::Stream* stream, - TensorShapeVector& output_dims) const; - -template Status ReduceKernel::ReduceKernelShared( - const MLFloat16* X, - const TensorShape& input_shape, - MLFloat16* Y, - const TensorShape& output_shape, - miopenReduceTensorOp_t miopen_reduce_op, - miopenHandle_t miopen_handle, - onnxruntime::Stream* stream, - TensorShapeVector& output_dims) const; - -// `input_shape_override` (if provided) is the input shape for compute purposes -Status PrepareForReduce(const Tensor* X, - bool keepdims, - gsl::span axes, - PrepareReduceMetadata& prepare_reduce_metadata, - const TensorShape* input_shape_override) { - ORT_ENFORCE(nullptr != X); - - const TensorShape& input_shape = input_shape_override ? *input_shape_override : X->Shape(); - const int64_t rank = gsl::narrow(input_shape.NumDimensions()); - prepare_reduce_metadata.input_count = input_shape.Size(); - - if (rank > 8) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "MIOpen only supports up to 8-D tensors in reduction"); - } - - const auto input_dims = input_shape.GetDims(); - std::vector reduced(rank, false); - if (axes.size() > 0) { - prepare_reduce_metadata.output_dims = input_shape.AsShapeVector(); - for (auto axis : axes) { - axis = HandleNegativeAxis(axis, rank); - ORT_ENFORCE(input_dims[axis] != 0, - "Can't reduce on dim with value of 0 if 'keepdims' is false. " - "Invalid output shape would be produced. input_shape:", - input_shape); - prepare_reduce_metadata.output_dims[axis] = 1; - reduced[axis] = true; - } - } else { - // no axes provided (i.e.) default axes => reduce on all dims - prepare_reduce_metadata.output_dims.reserve(input_dims.size()); - for (auto dim : input_dims) { - ORT_ENFORCE(keepdims || dim != 0, - "Can't reduce on dim with value of 0 if 'keepdims' is false. " - "Invalid output shape would be produced. input_shape:", - input_shape); - prepare_reduce_metadata.output_dims.push_back(dim == 0 ? 0 : 1); - } - } - - if (keepdims) { - prepare_reduce_metadata.squeezed_output_dims = prepare_reduce_metadata.output_dims; - } else if (axes.size() > 0) { - // we are not going to keep the reduced dims, hence compute the final output dim accordingly - prepare_reduce_metadata.squeezed_output_dims.reserve(rank); // even though we won't use the full capacity, it is better to reserve for peak possible usage - for (auto i = 0; i < rank; ++i) { - if (!reduced[i]) - prepare_reduce_metadata.squeezed_output_dims.push_back(input_dims[i]); - } - } else { - // 'axes' is empty and keepdims is false => we reduce on all axes AND drop all dims, - // so the result is just a scalar, we keep 'squeezed_output_dims' empty (i.e.) 
no-op - } - - // MIOpen requires at least 3D input, so pad 1s if needed - prepare_reduce_metadata.input_dims_miopen = input_shape.AsShapeVector(); - prepare_reduce_metadata.output_dims_miopen = prepare_reduce_metadata.output_dims; - if (rank < 3) { - TensorShapeVector pads(3 - rank, 1); - prepare_reduce_metadata.input_dims_miopen.insert(prepare_reduce_metadata.input_dims_miopen.end(), pads.begin(), pads.end()); - prepare_reduce_metadata.output_dims_miopen.insert(prepare_reduce_metadata.output_dims_miopen.end(), pads.begin(), pads.end()); - } - - prepare_reduce_metadata.output_count = TensorShape(prepare_reduce_metadata.output_dims).Size(); - - return Status::OK(); -} - -// `input_shape_override` is the input shape for compute purposes (if provided) -template -Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, - /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, - gsl::span axes, - bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, - const TensorShape* input_shape_override) { - typedef typename ToHipType::MappedType HipT; - const TensorShape& input_shape = input_shape_override ? *input_shape_override : input.Shape(); - hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - - int64_t input_count = prepare_reduce_metadata.input_count; - int64_t output_count = prepare_reduce_metadata.output_count; - auto& output_dims = prepare_reduce_metadata.output_dims; - auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; - auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; - // special case when there is a dim value of 0 in the shape. - if (input_count == 0) { - assert(output.Shape().Size() == 0); - return Status::OK(); - } - - // Block of fast matrix reduction. - if (fast_reduction) { - int m{}, n{}; - const auto applicable_matrix_reduction = - get_applicable_matrix_reduction(miopen_reduce_op, input_shape.GetDims(), axes, m, n); - if (applicable_matrix_reduction != ApplicableMatrixReduction::None) { - IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); - const HipT* input_data = reinterpret_cast(input.Data()); - if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitRocmNotificationOnDevice); - input_data = reinterpret_cast(input_data_buffer.get()); - fast_divmod tmp_div; - Impl_Mul(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, - reinterpret_cast(input.Data()), nullptr, - reinterpret_cast(input.Data()), nullptr, tmp_div, tmp_div, - reinterpret_cast(input_data_buffer.get()), input_count); - input_data = reinterpret_cast(input_data_buffer.get()); - } - - switch (applicable_matrix_reduction) { - case ApplicableMatrixReduction::Rows: { - ORT_RETURN_IF_ERROR(reduce_matrix_rows( - stream, input_data, reinterpret_cast(output.MutableData()), m, n)); - } break; - case ApplicableMatrixReduction::Columns: { - const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); - auto buffer = buffer_size_bytes == 0 ? 
nullptr : IAllocator::MakeUniquePtr(gpu_allocator, buffer_size_bytes, false, ort_stream, WaitRocmNotificationOnDevice); - ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data, - reinterpret_cast(output.MutableData()), m, n, - buffer.get(), buffer_size_bytes)); - } break; - default: { - ORT_ENFORCE(false, "Invild matrix reduction type."); - } - } - - if (calculate_log) { - Impl_Log(stream, reinterpret_cast(output.Data()), - reinterpret_cast(output.MutableData()), output_count); - } else if (miopen_reduce_op == MIOPEN_REDUCE_TENSOR_AVG) { - float denominator_float = applicable_matrix_reduction == ApplicableMatrixReduction::Rows - ? static_cast(m) - : static_cast(n); - HipT denominator = ToHipType::FromFloat(denominator_float); - UnaryDiv(stream, reinterpret_cast(output.Data()), - reinterpret_cast(output.MutableData()), denominator, output_count); - } - - return Status::OK(); - } - } - - // This reduction keep adding values to this buffer. If a non-zero value, say 1000, is here, the sum will start with 1000. - // Therefore zeroing out the memory is required - HIP_RETURN_IF_ERROR(hipMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream)); - - IAllocatorUniquePtr temp_X; - miopenDataType_t miopen_type_X = miopenFloat; - - // unlike bfp16 not supported in cudnn, miopen call for bfp16 succeeded below, however, UT shows data error - // so for now, follow the same logic in cudnn and convert input to fp32 then call miopen - if ((ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same::value) || - (ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_NO_INDICES && std::is_same::value)) { - // ArgMax/ArgMin with FP16 are not supported by miopen, so convert input to fp32 then call miopen - temp_X = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitRocmNotificationOnDevice); - Impl_Cast(stream, reinterpret_cast(input.Data()), temp_X.get(), input_shape.Size()); - } else { - miopen_type_X = MiopenTensor::GetDataType(); - } - - MiopenReduceDescriptor reduce_desc; - if constexpr (std::is_same::value || std::is_same::value) { - ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, MiopenTensor::GetDataType(), ReduceTensorIndices)); - } else { - ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, ReduceTensorIndices)); - } - - const auto one = ReduceConsts::One; - const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; - MiopenTensor output_tensor; - ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); - ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X)); - size_t workspace_bytes = 0; - RocmStream* rocm_stream = static_cast(ort_stream); - MIOPEN_RETURN_IF_ERROR(miopenGetReductionWorkspaceSize(RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, - input_tensor, output_tensor, &workspace_bytes)); - auto workspace_rocm = workspace_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, workspace_bytes, false, ort_stream, WaitRocmNotificationOnDevice); - - size_t indices_bytes = 0; - MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, - input_tensor, output_tensor, &indices_bytes)); - auto indices_rocm = indices_bytes == 0 ? 
nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitRocmNotificationOnDevice); - - if (ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_NO_INDICES) { - IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); - HipT* input_data = nullptr; - if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitRocmNotificationOnDevice); - input_data = reinterpret_cast(input_data_buffer.get()); - fast_divmod tmp_div; - Impl_Mul(stream, - static_cast(SimpleBroadcast::NoBroadcast), nullptr, - reinterpret_cast(input.Data()), nullptr, - reinterpret_cast(input.Data()), nullptr, - tmp_div, tmp_div, - input_data, input_count); - } else if (log_sum_exp) { - // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case - // This happens when the input is Scalar - if (input_count == output_count) { - if (output.MutableData() != input.Data()) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.MutableData(), input.Data(), input_count * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } - } else { - // Reduce max -- Max/Min will output indices data - MiopenReduceDescriptor reduce_max_desc; - miopenDataType_t miopen_reduce_max_type = miopen_type_X; - if ((std::is_same::value)) { - miopen_reduce_max_type = miopenFloat; - } - ORT_RETURN_IF_ERROR(reduce_max_desc.Set(MIOPEN_REDUCE_TENSOR_MAX, miopen_reduce_max_type, MIOPEN_REDUCE_TENSOR_NO_INDICES)); - size_t indices_bytes_max = 0; - MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, - input_tensor, output_tensor, &indices_bytes_max)); - auto indices_rocm_max = indices_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitRocmNotificationOnDevice); - auto* p_output = reinterpret_cast(output.template MutableData()); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, indices_rocm_max.get(), indices_bytes_max, - workspace_rocm.get(), workspace_bytes, - &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, p_output)); - } - - // Exp(X-ReduceMax) - const TensorShape output_shape(output_dims); - auto exp_result_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitRocmNotificationOnDevice); - auto exp_result = exp_result_buffer.get(); - auto log_sum_result_buffer = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitRocmNotificationOnDevice); - auto log_sum_result = log_sum_result_buffer.get(); - BinaryElementwisePreparation prepare; - ORT_RETURN_IF_ERROR(prepare.BinaryElementwiseBroadcastPrepareHelper(input_shape, output_shape, input_shape)); - Impl_Sub(stream, - prepare.output_rank_or_simple_broadcast, - &prepare.lhs_padded_strides, - reinterpret_cast(input.Data()), - &prepare.rhs_padded_strides, - reinterpret_cast(output.MutableData()), - &prepare.fdm_output_strides, - prepare.fdm_H, prepare.fdm_C, - reinterpret_cast(exp_result), input_count); - - Impl_Exp(stream, - reinterpret_cast(exp_result), - reinterpret_cast(exp_result), - input_count); - - // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case - // This happens when the input is Scalar. We do not need to add anything in this case. 
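The log_sum_exp branch above (and its continuation just below) computes ReduceLogSumExp as Log(ReduceSum(Exp(X - ReduceMax))) + ReduceMax so the exponentials cannot overflow. The following is a minimal CPU-side sketch of that same decomposition in plain standard C++, kept separate from the MIOpen/HIP kernel calls in the diff; all names in it are illustrative and not part of the ROCm EP.

// logsumexp_sketch.cc -- illustrative only; mirrors the reduction order used above:
// ReduceMax -> Exp(X - ReduceMax) -> ReduceSum -> Log(Sum) -> add ReduceMax back.
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

double LogSumExp(const std::vector<double>& x) {
  if (x.empty()) return -std::numeric_limits<double>::infinity();
  const double x_max = *std::max_element(x.begin(), x.end());   // ReduceMax
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - x_max);                // Exp(X - ReduceMax), then ReduceSum
  return std::log(sum) + x_max;                                 // Log(Sum) + ReduceMax
}

int main() {
  std::vector<double> x{1000.0, 1000.5, 999.0};   // naive exp() of these values overflows a double
  std::cout << LogSumExp(x) << "\n";              // finite result, no overflow
  return 0;
}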
- if (input_count == output_count) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(reinterpret_cast(log_sum_result), exp_result, input_count * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } else { - // ReduceSum - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, - workspace_rocm.get(), workspace_bytes, - &one, input_tensor, exp_result, - &zero, output_tensor, reinterpret_cast(log_sum_result))); - } - - // Log(Sum) - Impl_Log(stream, reinterpret_cast(log_sum_result), - reinterpret_cast(log_sum_result), - output_count); - - // Log + ReduceMax - fast_divmod tmp_div; - Impl_Add(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, - reinterpret_cast(log_sum_result), nullptr, - reinterpret_cast(output.MutableData()), nullptr, - tmp_div, tmp_div, - reinterpret_cast(output.MutableData()), output_count); - - return Status::OK(); - } - if (calculate_sqt) { - // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case - // This happens when the input is Scalar. We do not need to add anything in this case. - if (input_count == output_count) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(reinterpret_cast(output.MutableData()), input_data, input_count * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } else { - auto* p_output = reinterpret_cast(output.template MutableData()); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, - workspace_rocm.get(), workspace_bytes, - &one, input_tensor, input_data, - &zero, output_tensor, p_output)); - } - } else { - // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case - if (input_count == output_count) { - if (output.MutableData() != input.Data()) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.MutableData(), input.Data(), input_count * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } - } else { - if (temp_X) { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitRocmNotificationOnDevice); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, - workspace_rocm.get(), workspace_bytes, - &one, input_tensor, temp_X.get(), - &zero, output_tensor, temp_output.get())); - - Impl_Cast(stream, temp_output.get(), reinterpret_cast(output.MutableData()), output_count); - } else { - auto* p_output = reinterpret_cast(output.template MutableData()); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, - workspace_rocm.get(), workspace_bytes, - &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, p_output)); - } - } - } - } else { - // For ArgMax & ArgMin ops, use the indicies as the output with int64 type - // miopenReduceTensor has issue if input and output has same size, which will happen if the axis to be reduced has dim value of 1. - // the output is zeros of the output size - if (input_count == output_count) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(output.MutableData(), static_cast(0), output_count * sizeof(int64_t), stream)); - } else { - if (temp_X) { - auto temp_output = output_count == 0 ? 
nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitRocmNotificationOnDevice); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, - workspace_rocm.get(), workspace_bytes, - &one, input_tensor, temp_X.get(), - &zero, output_tensor, temp_output.get())); - } else { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitRocmNotificationOnDevice); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( - RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, - workspace_rocm.get(), workspace_bytes, - &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, temp_output.get())); - } - - // MIOpen reduction index is uint32_t for now, cast it to int64_t according to ONNX spec - Impl_Cast(stream, reinterpret_cast(indices_rocm.get()), output.MutableData(), output_count); - } - } - - if (calculate_log) { - Impl_Log(stream, - reinterpret_cast(output.MutableData()), - reinterpret_cast(output.MutableData()), - output_count); - } - - return Status::OK(); -} - -template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, - /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, - gsl::span axes, - bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, - const TensorShape* input_shape_override); - -// template Status ReduceComputeCore( -// const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, -// /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, -// gsl::span axes, -// bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, -// Stream* ort_stream, -// const TensorShape* input_shape_override); - -template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, - /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, - gsl::span axes, - bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, - const TensorShape* input_shape_override); - -template -template -Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const { - const Tensor* X = ctx->Input(0); - TensorShapeVector axes; - - size_t num_inputs = ctx->InputCount(); - const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input(1) : nullptr; // optional input. may be nullptr. - if (axes_tensor != nullptr) { - // override the attribute value with the input value for reduction_axes - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->Data(); - axes.assign(data, data + nDims); - } else { - axes.assign(axes_.begin(), axes_.end()); - } - - // empty axes and no-op - if (axes.empty() && noop_with_empty_axes_) { - auto* Y = ctx->Output(0, X->Shape()); - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), - hipMemcpyDeviceToDevice, Stream(ctx))); - return Status::OK(); - } - -#ifdef ENABLE_TRAINING - // Use ATen for ReduceSum if possible. 
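ComputeImpl above resolves the reduction axes from the optional axes input when it is present, otherwise from the attribute, and treats empty axes together with noop_with_empty_axes as an identity copy. Below is a small standalone sketch of that resolution policy, including the negative-axis handling, in plain C++; the function and struct names are illustrative stand-ins, not the EP's types.

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Mirrors the axes-selection policy: input overrides attribute; empty axes + noop flag => identity.
struct ResolvedAxes {
  std::vector<int64_t> axes;
  bool is_identity = false;
};

ResolvedAxes ResolveReduceAxes(const std::optional<std::vector<int64_t>>& axes_input,
                               const std::vector<int64_t>& axes_attribute,
                               bool noop_with_empty_axes,
                               int64_t rank) {
  ResolvedAxes result;
  result.axes = axes_input.has_value() ? *axes_input : axes_attribute;
  if (result.axes.empty() && noop_with_empty_axes) {
    result.is_identity = true;  // the kernel above just memcpy's input to output in this case
    return result;
  }
  for (auto& axis : result.axes) {  // HandleNegativeAxis equivalent
    if (axis < -rank || axis >= rank) throw std::out_of_range("reduction axis out of range");
    if (axis < 0) axis += rank;
  }
  return result;
}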
- const TensorShape& input_shape = X->Shape(); - if (contrib::IsATenOperatorExecutorInitialized() && miopen_reduce_op == MIOPEN_REDUCE_TENSOR_ADD && !calculate_log_ && - !calculate_sqt_ && !log_sum_exp_ && input_shape.Size() > 0) { - if (axes.empty()) { - axes.resize(input_shape.NumDimensions()); - std::iota(axes.begin(), axes.end(), 0); - } - ORT_RETURN_IF_ERROR(contrib::ExecuteReduceSumATen(ctx, axes, keepdims_)); - return Status::OK(); - } -#endif - - PrepareReduceMetadata prepare_reduce_metadata; - ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); - Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); - const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, miopen_reduce_op, axes, - calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream()); -} - -#define SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(T) \ - template <> \ - template <> \ - Status ReduceKernel::ComputeImpl( \ - OpKernelContext * ctx, miopenReduceTensorOp_t miopen_reduce_op) const { \ - typedef typename ToHipType::MappedType HipT; \ - const Tensor* X = ctx->Input(0); \ - TensorShapeVector axes; \ - size_t num_inputs = ctx->InputCount(); \ - if (num_inputs == 2) { \ - const Tensor* axes_tensor = ctx->Input(1); \ - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); \ - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); \ - auto nDims = static_cast(axes_tensor->Shape()[0]); \ - const auto* data = axes_tensor->Data(); \ - axes.assign(data, data + nDims); \ - } else { \ - axes.assign(axes_.begin(), axes_.end()); \ - } \ - \ - if (axes.empty() && noop_with_empty_axes_) { \ - auto* Y = ctx->Output(0, X->Shape()); \ - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), \ - hipMemcpyDeviceToDevice, Stream(ctx))); \ - return Status::OK(); \ - } \ - \ - PrepareReduceMetadata prepare_reduce_metadata; \ - ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); \ - \ - Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); \ - \ - int64_t input_count = prepare_reduce_metadata.input_count; \ - int64_t output_count = prepare_reduce_metadata.output_count; \ - auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; \ - auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; \ - \ - if (input_count == 0) { \ - assert(Y->Shape().Size() == 0); \ - return Status::OK(); \ - } \ - \ - if (input_count == output_count) { \ - if (Y->MutableData() != X->Data()) { \ - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), \ - input_count * sizeof(T), hipMemcpyDeviceToDevice, Stream(ctx))); \ - } \ - return Status::OK(); \ - } \ - \ - HIP_RETURN_IF_ERROR(hipMemsetAsync(Y->MutableDataRaw(), 0, Y->SizeInBytes(), Stream(ctx))); \ - \ - size_t indices_bytes = 0; \ - size_t workspace_bytes = 0; \ - MiopenTensor input_tensor; \ - MiopenTensor output_tensor; \ - MiopenReduceDescriptor reduce_desc; \ - \ - miopenDataType_t miopen_type_X = miopenFloat; \ - IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, ctx->GetComputeStream()); \ - Impl_Cast(Stream(ctx), reinterpret_cast(X->Data()), temp_X.get(), X->Shape().Size()); \ - \ - ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, MIOPEN_REDUCE_TENSOR_NO_INDICES)); \ - 
ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); \ - ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X)); \ - MIOPEN_RETURN_IF_ERROR( \ - miopenGetReductionIndicesSize(GetMiopenHandle(ctx), reduce_desc, input_tensor, output_tensor, &indices_bytes)); \ - MIOPEN_RETURN_IF_ERROR( \ - miopenGetReductionWorkspaceSize(GetMiopenHandle(ctx), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); \ - IAllocatorUniquePtr indices_rocm = GetScratchBuffer(indices_bytes, ctx->GetComputeStream()); \ - IAllocatorUniquePtr workspace_rocm = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); \ - \ - const auto one = Consts::One; \ - const auto zero = Consts::Zero; \ - auto temp_Y = GetScratchBuffer(output_count, ctx->GetComputeStream()); \ - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(GetMiopenHandle(ctx), reduce_desc, indices_rocm.get(), indices_bytes, \ - workspace_rocm.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \ - &zero, output_tensor, temp_Y.get())); \ - Impl_Cast(Stream(ctx), temp_Y.get(), reinterpret_cast(Y->MutableData()), output_count); \ - \ - return Status::OK(); \ - } - -SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int32_t) -SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int64_t) -SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int8_t) -SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(uint8_t) - -namespace ReductionOps { - -template -std::unique_ptr ReduceCompute(const AllocatorPtr& gpu_allocator, miopenReduceTensorOp_t miopen_reduce_op, AllocatorPtr allocator, - const Tensor& input, gsl::span axes, - bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp, - bool fast_reduction, Stream* stream, const TensorShape* input_shape_override) { - PrepareReduceMetadata prepare_reduce_metadata; - auto status = PrepareForReduce(&input, - keep_dims, - axes, - prepare_reduce_metadata, - input_shape_override); - - if (!status.IsOK()) { - ORT_THROW(ONNXRUNTIME, FAIL, "Failed to perform reduce op: ", status.ErrorMessage()); - } - - auto output = Tensor::Create(input.DataType(), prepare_reduce_metadata.squeezed_output_dims, std::move(allocator)); - - status = ReduceComputeCore(gpu_allocator, input, prepare_reduce_metadata, *output, miopen_reduce_op, axes, - calculate_log, calculate_sqt, log_sum_exp, fast_reduction, stream, input_shape_override); - - if (!status.IsOK()) { - ORT_THROW(ONNXRUNTIME, FAIL, "Failed to perform reduce op: ", status.ErrorMessage()); - } - - return output; -} - -// Explicit template instantiation (needed to be used in einsum_auxiliary_ops.cc) - -template std::unique_ptr ReduceCompute( - const AllocatorPtr& gpu_allocator, miopenReduceTensorOp_t miopen_reduce_op, - AllocatorPtr allocator, - const Tensor& input, gsl::span axes, - bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp, - bool fast_reduction, Stream* stream, const TensorShape* input_shape_override); - -// template std::unique_ptr ReduceCompute( -// ROCMExecutionProvider& rocm_ep, miopenReduceTensorOp_t miopen_reduce_op, -// AllocatorPtr allocator, -// const Tensor& input, gsl::span axes, -// bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp, -// bool fast_reduction, Stream* stream, const TensorShape* input_shape_override); - -template std::unique_ptr ReduceCompute( - const AllocatorPtr& gpu_allocator, miopenReduceTensorOp_t miopen_reduce_op, - AllocatorPtr allocator, - const Tensor& input, gsl::span axes, - bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp, - bool fast_reduction, Stream* stream, const TensorShape* 
input_shape_override); - -} // namespace ReductionOps - -REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16) -REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float) -// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double) - -REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16) -REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float) -// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int32_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int64_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int8_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, uint8_t, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, BFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, int32_t, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int32_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int64_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int8_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, uint8_t, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, BFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, int32_t, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, MLFloat16, 12, 13) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, float, 12, 13) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, double, 12, 13) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, int32_t, 12, 13) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, int64_t, 12, 13) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, BFloat16, 12, 13) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, BFloat16, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, BFloat16, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, float, 17, 18) -// 
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, BFloat16, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, BFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, int32_t, 17, 18) - -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, float, 17, 18) -// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, BFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, int32_t, 17, 18) - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_allocator.cc b/onnxruntime/core/providers/rocm/rocm_allocator.cc deleted file mode 100644 index 27861a567a7f4..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_allocator.cc +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "rocm_allocator.h" -#include "rocm_common.h" -#include "gpu_data_transfer.h" - -namespace onnxruntime { - -void ROCMAllocator::CheckDevice(bool throw_when_fail) const { -#ifndef NDEBUG - // check device to match at debug build - // if it's expected to change, call hipSetDevice instead of the check - int current_device; - auto hip_err = hipGetDevice(¤t_device); - if (hip_err == hipSuccess) { - ORT_ENFORCE(current_device == Info().id); - } else if (throw_when_fail) { - HIP_CALL_THROW(hip_err); - } -#else - ORT_UNUSED_PARAMETER(throw_when_fail); -#endif -} - -void ROCMAllocator::SetDevice(bool throw_when_fail) const { - int current_device; - auto hip_err = hipGetDevice(¤t_device); - if (hip_err == hipSuccess) { - int allocator_device_id = Info().id; - if (current_device != allocator_device_id) { - hip_err = hipSetDevice(allocator_device_id); - } - } - - if (hip_err != hipSuccess && throw_when_fail) { - HIP_CALL_THROW(hip_err); - } -} - -void* ROCMAllocator::Alloc(size_t size) { - SetDevice(true); - CheckDevice(true); - void* p = nullptr; - if (size > 0) { - // BFCArena was updated recently to handle the exception and adjust the request size - HIP_CALL_THROW(hipMalloc((void**)&p, size)); - } - return p; -} - -void ROCMAllocator::Free(void* p) { - SetDevice(false); - CheckDevice(false); // ignore ROCM failure when free - ORT_IGNORE_RETURN_VALUE(hipFree(p)); // do not throw error since it's OK for hipFree to fail during shutdown -} - -void* ROCMExternalAllocator::Alloc(size_t size) { - void* p = nullptr; - if (size > 0) { - p = alloc_(size); - - // review(codemzs): ORT_ENFORCE does not seem appropriate. 
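ROCMAllocator::Alloc and ::Free above switch to the allocator's device (SetDevice) before calling hipMalloc/hipFree so that allocations land on the correct GPU regardless of the calling thread. Here is a minimal RAII sketch of that device-switching pattern, assuming only the core HIP runtime calls hipGetDevice/hipSetDevice/hipMalloc; unlike the allocator above, this guard also restores the previous device on scope exit, and it is an illustration rather than the EP's implementation.

#include <hip/hip_runtime.h>
#include <cstddef>
#include <stdexcept>

// Saves the current device, switches to `device_id`, and restores the old device on scope exit.
class ScopedHipDevice {
 public:
  explicit ScopedHipDevice(int device_id) {
    if (hipGetDevice(&previous_) != hipSuccess) throw std::runtime_error("hipGetDevice failed");
    if (previous_ != device_id) {
      if (hipSetDevice(device_id) != hipSuccess) throw std::runtime_error("hipSetDevice failed");
      restore_ = true;
    }
  }
  ~ScopedHipDevice() {
    if (restore_) (void)hipSetDevice(previous_);  // best effort; never throw from a destructor
  }

 private:
  int previous_ = 0;
  bool restore_ = false;
};

void* AllocOnDevice(int device_id, size_t size) {
  ScopedHipDevice guard(device_id);
  void* p = nullptr;
  if (size > 0 && hipMalloc(&p, size) != hipSuccess) throw std::runtime_error("hipMalloc failed");
  return p;
}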
- ORT_ENFORCE(p != nullptr); - } - - return p; -} - -void ROCMExternalAllocator::Free(void* p) { - free_(p); - std::lock_guard lock(lock_); - auto it = reserved_.find(p); - if (it != reserved_.end()) { - reserved_.erase(it); - if (empty_cache_) empty_cache_(); - } -} - -void* ROCMExternalAllocator::Reserve(size_t size) { - void* p = Alloc(size); - if (!p) return nullptr; - std::lock_guard lock(lock_); - ORT_ENFORCE(reserved_.find(p) == reserved_.end()); - reserved_.insert(p); - return p; -} - -void* ROCMPinnedAllocator::Alloc(size_t size) { - void* p = nullptr; - if (size > 0) { - HIP_CALL_THROW(hipHostMalloc((void**)&p, size)); - } - return p; -} - -void ROCMPinnedAllocator::Free(void* p) { - HIP_CALL_THROW(hipHostFree(p)); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_allocator.h b/onnxruntime/core/providers/rocm/rocm_allocator.h deleted file mode 100644 index ae7982ae6c618..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_allocator.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/inlined_containers.h" -#include "core/framework/allocator.h" -#include - -namespace onnxruntime { - -class ROCMAllocator : public IAllocator { - public: - ROCMAllocator(OrtDevice::DeviceId device_id, const char* name) - : IAllocator( - OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id), - OrtMemTypeDefault)) {} - void* Alloc(size_t size) override; - void Free(void* p) override; - - private: - void CheckDevice(bool throw_when_fail) const; - void SetDevice(bool throw_when_fail) const; -}; - -class ROCMExternalAllocator : public ROCMAllocator { - typedef void* (*ExternalAlloc)(size_t size); - typedef void (*ExternalFree)(void* p); - typedef void (*ExternalEmptyCache)(); - - public: - ROCMExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) - : ROCMAllocator(device_id, name) { - alloc_ = reinterpret_cast(alloc); - free_ = reinterpret_cast(free); - empty_cache_ = reinterpret_cast(empty_cache); - } - - void* Alloc(size_t size) override; - void Free(void* p) override; - void* Reserve(size_t size) override; - - private: - mutable std::mutex lock_; - ExternalAlloc alloc_; - ExternalFree free_; - ExternalEmptyCache empty_cache_; - InlinedHashSet reserved_; -}; - -// TODO: add a default constructor -class ROCMPinnedAllocator : public IAllocator { - public: - ROCMPinnedAllocator(const char* name) - : IAllocator( - OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, - 0 /*CPU device always with id 0*/), - OrtMemTypeCPUOutput)) {} - - void* Alloc(size_t size) override; - void Free(void* p) override; -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc deleted file mode 100644 index a73ef9b34b4de..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/providers/shared_library/provider_api.h" -#include "shared_inc/rocm_call.h" -#include - -#ifdef _WIN32 -#else // POSIX -#include -#include -#endif - -namespace onnxruntime { - -using namespace common; - -template -const char* RocmErrString(ERRTYPE) { - ORT_NOT_IMPLEMENTED(); -} - -#define CASE_ENUM_TO_STR(x) \ - case x: \ - return #x - -template <> -const char* RocmErrString(hipError_t x) { - ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - return hipGetErrorString(x); -} - -template <> -const char* RocmErrString(rocblas_status e) { - ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - switch (e) { - CASE_ENUM_TO_STR(rocblas_status_success); - CASE_ENUM_TO_STR(rocblas_status_invalid_handle); - CASE_ENUM_TO_STR(rocblas_status_not_implemented); - CASE_ENUM_TO_STR(rocblas_status_invalid_pointer); - CASE_ENUM_TO_STR(rocblas_status_size_query_mismatch); - CASE_ENUM_TO_STR(rocblas_status_invalid_size); - CASE_ENUM_TO_STR(rocblas_status_memory_error); - CASE_ENUM_TO_STR(rocblas_status_internal_error); - CASE_ENUM_TO_STR(rocblas_status_perf_degraded); - CASE_ENUM_TO_STR(rocblas_status_size_increased); - CASE_ENUM_TO_STR(rocblas_status_size_unchanged); - CASE_ENUM_TO_STR(rocblas_status_invalid_value); - CASE_ENUM_TO_STR(rocblas_status_continue); - default: - return "(look for rocblas_status in rocblas-types.h)"; - } -} - -template <> -const char* RocmErrString(hipblasStatus_t e) { - ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - switch (e) { - CASE_ENUM_TO_STR(HIPBLAS_STATUS_SUCCESS); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_INITIALIZED); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_ALLOC_FAILED); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_INVALID_VALUE); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_ARCH_MISMATCH); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_MAPPING_ERROR); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_EXECUTION_FAILED); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_INTERNAL_ERROR); - CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_SUPPORTED); - default: - return "(look for HIPBLAS_STATUS_xxx in hipblas_api.h)"; - } -} - -template <> -const char* RocmErrString(hiprandStatus_t) { - ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - return "(see hiprand.h & look for hiprandStatus_t or HIPRAND_STATUS_xxx)"; -} - -template <> -const char* RocmErrString(miopenStatus_t e) { - ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - return miopenGetErrorString(e); -} - -template <> -const char* RocmErrString(hipfftResult e) { - ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - switch (e) { - CASE_ENUM_TO_STR(HIPFFT_SUCCESS); - CASE_ENUM_TO_STR(HIPFFT_ALLOC_FAILED); - CASE_ENUM_TO_STR(HIPFFT_INVALID_VALUE); - CASE_ENUM_TO_STR(HIPFFT_INTERNAL_ERROR); - CASE_ENUM_TO_STR(HIPFFT_SETUP_FAILED); - CASE_ENUM_TO_STR(HIPFFT_INVALID_SIZE); - default: - return "Unknown hipfft error status"; - } -} - -#ifdef ORT_USE_NCCL -template <> -const char* RocmErrString(ncclResult_t e) { - ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - return ncclGetErrorString(e); -} -#endif - -template -std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { - if (retCode != successCode) { - try { -#ifdef _WIN32 - std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); - if (hostname_str.empty()) { - hostname_str = "?"; - } - const char* hostname = 
hostname_str.c_str(); -#else - char hostname[HOST_NAME_MAX]; - if (gethostname(hostname, HOST_NAME_MAX) != 0) - strcpy(hostname, "?"); -#endif - int currentHipDevice = -1; - ORT_IGNORE_RETURN_VALUE(hipGetDevice(¤tHipDevice)); // void to silence nodiscard - ORT_IGNORE_RETURN_VALUE(hipGetLastError()); // clear last ROCM error; void to silence nodiscard - static char str[1024]; - snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", - libName, (int)retCode, RocmErrString(retCode), currentHipDevice, - hostname, - file, line, exprString, msg); - if constexpr (THRW) { - // throw an exception with the error info - ORT_THROW(str); - } else { - LOGS_DEFAULT(ERROR) << str; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); - } - } catch (const std::exception& e) { // catch, log, and rethrow since ROCM code sometimes hangs in destruction, so we'd never get to see the error - if constexpr (THRW) { - ORT_THROW(e.what()); - } else { - LOGS_DEFAULT(ERROR) << e.what(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); - } - } - } - if constexpr (!THRW) { - return Status::OK(); - } -} - -template Status RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); -template Status RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -template Status RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); -template void RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); -template Status RocmCall(miopenStatus_t retCode, const char* exprString, const char* libName, miopenStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(miopenStatus_t retCode, const char* exprString, const char* libName, miopenStatus_t successCode, const char* msg, const char* file, const int line); -template Status RocmCall(hiprandStatus_t retCode, const char* exprString, const char* libName, hiprandStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hiprandStatus_t retCode, const char* exprString, const char* libName, hiprandStatus_t successCode, const char* msg, const char* file, const int line); -template Status RocmCall(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line); -template Status RocmCall(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line); - 
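RocmCall above is instantiated twice per error type: once returning a Status and once throwing, selected by the THRW template parameter with if constexpr. The sketch below strips that dual-mode wrapper down to plain hipError_t, using only hipGetErrorString from the HIP runtime; SimpleStatus and the *_SKETCH macro names are stand-ins for onnxruntime's Status and macros, not the real ones.

#include <hip/hip_runtime.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>

struct SimpleStatus {        // stand-in for onnxruntime::common::Status
  bool ok = true;
  std::string message;
};

template <bool THRW>
std::conditional_t<THRW, void, SimpleStatus> CheckHip(hipError_t err, const char* expr,
                                                      const char* file, int line) {
  if (err != hipSuccess) {
    std::ostringstream oss;
    oss << "HIP failure " << static_cast<int>(err) << ": " << hipGetErrorString(err)
        << " ; expr=" << expr << " ; " << file << ":" << line;
    if constexpr (THRW) {
      throw std::runtime_error(oss.str());   // throwing flavor
    } else {
      return SimpleStatus{false, oss.str()}; // status-returning flavor
    }
  }
  if constexpr (!THRW) {
    return SimpleStatus{};
  }
}

#define HIP_CALL_SKETCH(expr) CheckHip<false>((expr), #expr, __FILE__, __LINE__)       // returns a status
#define HIP_CALL_THROW_SKETCH(expr) CheckHip<true>((expr), #expr, __FILE__, __LINE__)  // throws on error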
-#ifdef ORT_USE_NCCL -template Status RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); -#endif - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_common.h b/onnxruntime/core/providers/rocm/rocm_common.h deleted file mode 100644 index 4af1f40a6fccc..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_common.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/shared_library/provider_api.h" -#include "core/common/status.h" -#include "core/framework/float16.h" -#include "core/providers/rocm/rocm_pch.h" -#include "core/providers/rocm/shared_inc/rocm_call.h" -#include "core/providers/rocm/shared_inc/fast_divmod.h" -#include "core/util/math.h" -#include - -namespace onnxruntime { -namespace rocm { - -#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) -#define ROCBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(ROCBLAS_CALL(expr)) -#define HIPBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPBLAS_CALL(expr)) -#define HIPSPARSE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPSPARSE_CALL(expr)) -#define HIPRAND_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPRAND_CALL(expr)) -#define MIOPEN_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(MIOPEN_CALL(expr)) -#define MIOPEN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(MIOPEN_CALL2(expr, m)) -#define HIPFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPFFT_CALL(expr)) - -#ifdef USE_HIPBLASLT -#define HIPBLASLT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPBLASLT_CALL(expr)) -#endif - -// Type mapping for MLFloat16 to half -template -class ToHipType { - public: - typedef T MappedType; - static MappedType FromFloat(float f) { - return static_cast(f); - } -}; - -template <> -class ToHipType { - public: - typedef __half MappedType; - static MappedType FromFloat(float f) { - uint16_t h = math::floatToHalf(f); - return *reinterpret_cast(&h); - } -}; - -inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { - int stride = 1; - if (dims.empty() || p.size() < dims.size()) - return false; - auto rank = p.size(); - for (size_t i = 0; i < rank; i++) { - p[rank - 1 - i] = fast_divmod(stride); - if (i < dims.size() - 1) { - stride *= static_cast(dims[dims.size() - 1 - i]); - } - } - return true; -} - -inline int warpSizeDynamic() { - hipDeviceProp_t deviceProp; - HIP_CALL_THROW(hipGetDeviceProperties(&deviceProp, 0)); - return deviceProp.warpSize; -} - -inline void hipMemGetInfoAlt(uint32_t deviceId, size_t* pFree, size_t* pTotal) { - const auto status = hipMemGetInfo(pFree, pTotal); - if (status != hipSuccess) { - size_t usedMemory = 0; - ROCMSMI_CALL_THROW(rsmi_init(0)); - ROCMSMI_CALL_THROW(rsmi_dev_memory_total_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, pTotal)); - ROCMSMI_CALL_THROW(rsmi_dev_memory_usage_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, &usedMemory)); - *pFree = *pTotal - usedMemory; - ROCMSMI_CALL_THROW(rsmi_shut_down()); - } -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc deleted file mode 100644 index 6fcf23e346b6a..0000000000000 --- 
a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ /dev/null @@ -1,2555 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/inlined_containers.h" -#include "core/providers/shared_library/provider_api.h" -#include "core/platform/env_var_utils.h" -#include "core/providers/rocm/rocm_execution_provider.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/rocm_allocator.h" -#include "core/providers/rocm/rocm_fwd.h" -#include "core/providers/rocm/gpu_data_transfer.h" -#include "core/providers/rocm/rocm_profiler.h" -#include "core/session/onnxruntime_run_options_config_keys.h" - -#ifndef DISABLE_CONTRIB_OPS -#include "contrib_ops/rocm/rocm_contrib_kernels.h" -#endif - -#ifdef ENABLE_TRAINING_OPS -#include "orttraining/training_ops/rocm/rocm_training_kernels.h" -#endif - -#ifdef USE_TRITON_KERNEL -#include "core/providers/rocm/triton_kernel.h" -#endif - -#include "core/providers/rocm/rocm_stream_handle.h" - -using namespace onnxruntime::common; - -namespace onnxruntime { - -class Memcpy final : public OpKernel { - public: - Memcpy(const OpKernelInfo& info) : OpKernel{info} {} - - Status Compute(OpKernelContext* ctx) const override { - auto X_type = ctx->InputType(0); - if (X_type->IsTensorType()) { - const auto* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); - Tensor* Y = ctx->Output(0, X->Shape()); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); - // CopyTensorAsync could handle both pinned memory and non-pinned CPU memory. - // For non-pinned CPU memory, the copy is synchronous. 
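The Memcpy kernel above defers to CopyTensorAsync, and the comment notes that a copy from non-pinned (pageable) CPU memory is effectively synchronous, while pinned host memory permits a true asynchronous transfer on the stream. A minimal HIP sketch contrasting the two follows, assuming only core runtime calls (hipHostMalloc, hipMemcpyAsync, hipStreamSynchronize); buffer sizes and the HIP_CHECK macro are illustrative.

#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define HIP_CHECK(expr)                                                          \
  do {                                                                           \
    hipError_t err_ = (expr);                                                    \
    if (err_ != hipSuccess) {                                                    \
      std::fprintf(stderr, "%s failed: %s\n", #expr, hipGetErrorString(err_));   \
      std::exit(1);                                                              \
    }                                                                            \
  } while (0)

int main() {
  constexpr size_t kBytes = 1 << 20;
  hipStream_t stream;
  HIP_CHECK(hipStreamCreate(&stream));

  void* device_buf = nullptr;
  HIP_CHECK(hipMalloc(&device_buf, kBytes));

  // Pageable host memory: hipMemcpyAsync falls back to a staged, effectively synchronous copy.
  std::vector<char> pageable(kBytes, 0);
  HIP_CHECK(hipMemcpyAsync(device_buf, pageable.data(), kBytes, hipMemcpyHostToDevice, stream));

  // Pinned host memory: the copy is queued on the stream and can overlap with other work.
  void* pinned = nullptr;
  HIP_CHECK(hipHostMalloc(&pinned, kBytes));
  HIP_CHECK(hipMemcpyAsync(device_buf, pinned, kBytes, hipMemcpyHostToDevice, stream));

  HIP_CHECK(hipStreamSynchronize(stream));  // wait before reusing or freeing the host buffers
  HIP_CHECK(hipHostFree(pinned));
  HIP_CHECK(hipFree(device_buf));
  HIP_CHECK(hipStreamDestroy(stream));
  return 0;
}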
- ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); - return Status::OK(); - } else { - if (X_type->IsSparseTensorType()) { - // TODO: support aysnc copy for sparse tensor - // sync the stream first, since it is a sync memory copy - HIP_CALL_THROW(hipStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle()))); - const auto* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); - SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); - return X->Copy(Info().GetDataTransferManager(), *Y); - } else if (X_type->IsTensorSequenceType()) { - const TensorSeq* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); - TensorSeq* Y = ctx->Output(0); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); - auto X_dtype = X->DataType(); - Y->SetType(X_dtype); - AllocatorPtr alloc; - - // If we are copying contents to ROCM, the allocator to use - // to allocate the buffers of the new tensors in the sequence - // can be temp space allocator associated with the ROCM EP - if (Node().OpType() == "MemcpyFromHost") { - auto status = ctx->GetTempSpaceAllocator(&alloc); - if (!status.IsOK()) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Memcpy rocm: unable to get an allocator."); - } - } else { - // If we are copying contents to CPU (op type is "MemcpyToHost"), - // the allocator to use to allocate the buffers of the new tensors - // in the sequence will be the allocator from the CPU EP - auto status = ctx->GetTempSpaceCPUAllocator(&alloc); - if (!status.IsOK()) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Memcpy rocm: unable to get the CPU allocator."); - } - } - auto X_size = X->Size(); - Y->Reserve(X_size); - for (size_t i = 0; i < X_size; ++i) { - const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, - target_tensor->Location().device); - ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream())); - Y->Add(std::move(*target_tensor)); - } - return Status::OK(); - } - return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); - } - } -}; - -namespace rocm { -ONNX_OPERATOR_KERNEL_EX( - MemcpyFromHost, - kOnnxDomain, - 1, - kRocmExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 0) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()), - Memcpy); - -ONNX_OPERATOR_KERNEL_EX( - MemcpyToHost, - kOnnxDomain, - 1, - kRocmExecutionProvider, - (*KernelDefBuilder::Create()) - .OutputMemoryType(OrtMemTypeCPUOutput, 0) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()), - Memcpy); - -} // namespace rocm - -AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, - size_t gpu_mem_limit, - ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, - const OrtArenaCfg* default_memory_arena_cfg) { - if (external_allocator_info.UseExternalAllocator()) { - AllocatorCreationInfo default_memory_info( - [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, - external_allocator_info.alloc, - 
external_allocator_info.free, - external_allocator_info.empty_cache); - }, - device_id, - false); - - return CreateAllocator(default_memory_info); - } else { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, - // make it stream aware - true, - // enable cross stream sharing? - false); - - // ROCM malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); - } -} - -ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t /*gpu_mem_limit*/, - ArenaExtendStrategy /*arena_extend_strategy*/, ROCMExecutionProviderExternalAllocatorInfo /*external_allocator_info*/, - OrtArenaCfg* /*default_memory_arena_cfg*/) { - HIP_CALL_THROW(hipSetDevice(device_id)); - - HIPBLAS_CALL_THROW(hipblasCreate(&hipblas_handle_)); - HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); - - MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); - MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); - - hip_graph_.SetStream(stream); -} - -ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { - ORT_IGNORE_RETURN_VALUE(HIPBLAS_CALL(hipblasDestroy(hipblas_handle_))); - ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); -} - -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed( - RocmGraphAnnotation_t hip_graph_annotation_id) const { - if (!IsGraphCaptureAllowedOnRun(hip_graph_annotation_id)) { - return false; - } - if (graph_id_to_run_count_.find(hip_graph_annotation_id) == graph_id_to_run_count_.end()) { - return false; - } - return graph_id_to_run_count_.at(hip_graph_annotation_id) >= min_num_runs_before_hip_graph_capture_; -} - -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun( - RocmGraphAnnotation_t hip_graph_annotation_id) const { - return hip_graph_.IsGraphCaptureAllowedOnRun(hip_graph_annotation_id); -} - -RocmGraphAnnotation_t ROCMExecutionProvider::PerThreadContext::GetRocmGraphAnnotationId( - const onnxruntime::RunOptions& run_options) const { - auto graph_annotation_str = - run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); - // If graph annotation is not provided, fall back to the one hip graph per session behavior - RocmGraphAnnotation_t hip_graph_annotation_id = 0; - if (graph_annotation_str.has_value()) { - ORT_ENFORCE(TryParseStringWithClassicLocale(*graph_annotation_str, hip_graph_annotation_id), - "Failed to parse the hip graph annotation id: ", - *graph_annotation_str); - } - - return hip_graph_annotation_id; -} - -void ROCMExecutionProvider::PerThreadContext::CaptureBegin(RocmGraphAnnotation_t hip_graph_annotation_id) { - hip_graph_.CaptureBegin(hip_graph_annotation_id); -} - -void ROCMExecutionProvider::PerThreadContext::CaptureEnd(RocmGraphAnnotation_t hip_graph_annotation_id) { - hip_graph_.CaptureEnd(hip_graph_annotation_id); -} - -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(RocmGraphAnnotation_t graph_annotation_id) const { - return hip_graph_.IsGraphCaptured(graph_annotation_id); -} - -Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(RocmGraphAnnotation_t graph_annotation_id) { - return hip_graph_.Replay(graph_annotation_id); -} - -void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture( - 
RocmGraphAnnotation_t hip_graph_annotation_id) { - if (graph_id_to_run_count_.find(hip_graph_annotation_id) == graph_id_to_run_count_.end()) { - graph_id_to_run_count_[hip_graph_annotation_id] = 1; - return; - } - graph_id_to_run_count_[hip_graph_annotation_id]++; -} - -void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { - if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( - "ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead."); - env_tunable_op_enable.has_value() && env_tunable_op_enable != info.tunable_op.enable) { - LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_ENABLE is set to " << *env_tunable_op_enable; - info.tunable_op.enable = *env_tunable_op_enable; - } - - if (auto env_tunable_op_tuning_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( - "ORT_ROCM_TUNABLE_OP_TUNING_ENABLE", {"0", "1"}, - "Use provider_options \"tunable_op_tuning_enable\" instead."); - env_tunable_op_tuning_enable.has_value() && env_tunable_op_tuning_enable != info.tunable_op.tuning_enable) { - LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_TUNING_ENABLE is set to " << *env_tunable_op_tuning_enable; - info.tunable_op.tuning_enable = *env_tunable_op_tuning_enable; - } - - if (info.tunable_op.tuning_enable && !info.tunable_op.enable) { - LOGS_DEFAULT(WARNING) << "TunableOp is enabled for tuning but is not enabled for using. This will have no effect."; - } -} - -ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kRocmExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, - info.device_id)}, - info_{info}, - tuning_context_(this, &info_.tunable_op) { - HIP_CALL_THROW(hipSetDevice(info_.device_id)); - - // must wait GPU idle, otherwise hipGetDeviceProperties might fail - HIP_CALL_THROW(hipDeviceSynchronize()); - HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, info_.device_id)); - - // This scenario is not supported. 
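IsGraphCaptureAllowed and IncrementRegularRunCountBeforeGraphCapture above gate HIP graph capture per annotation id: a graph is captured only after the model has executed normally a minimum number of times (min_num_runs_before_hip_graph_capture_). The plain C++ sketch below isolates just that gating logic; the class name, threshold parameter, and usage comment are illustrative rather than the EP's code.

#include <unordered_map>

// Capture is allowed for an annotation id only after `min_runs` ordinary runs of that id.
class GraphCaptureGate {
 public:
  explicit GraphCaptureGate(int min_runs) : min_runs_(min_runs) {}

  bool IsCaptureAllowed(int annotation_id) const {
    auto it = run_count_.find(annotation_id);
    return it != run_count_.end() && it->second >= min_runs_;
  }

  void RecordRegularRun(int annotation_id) { ++run_count_[annotation_id]; }

 private:
  int min_runs_;
  std::unordered_map<int, int> run_count_;
};

// Typical per-run flow, mirroring OnRunStart/OnRunEnd above:
//   if (gate.IsCaptureAllowed(id)) { /* CaptureBegin(id) ... CaptureEnd(id), then replay */ }
//   else                           { gate.RecordRegularRun(id); }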
- ORT_ENFORCE(!(info.has_user_compute_stream && info.external_allocator_info.UseExternalAllocator())); - - if (info.has_user_compute_stream) { - external_stream_ = true; - use_ep_level_unified_stream_ = true; - stream_ = static_cast(info.user_compute_stream); - } else { - if (info.external_allocator_info.UseExternalAllocator()) { - use_ep_level_unified_stream_ = true; - stream_ = nullptr; - } else if (info.enable_hip_graph) { - // current hip graph implementation only works with single stream - // use EP level unified stream for all the reqeust - HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking)); - use_ep_level_unified_stream_ = true; - } else { - stream_ = nullptr; - } - } - - size_t free = 0; - size_t total = 0; - onnxruntime::rocm::hipMemGetInfoAlt(info_.device_id, &free, &total); - - OverrideTunableOpInfoByEnv(info_); - -#ifdef USE_TRITON_KERNEL - onnxruntime::rocm::LoadOrtTritonKernel(); -#endif -} - -ROCMExecutionProvider::~ROCMExecutionProvider() { - // clean up thread local context caches - { - std::lock_guard lock(context_state_.mutex); - for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { - const auto cache = cache_weak.lock(); - if (!cache) continue; - ORT_IGNORE_RETURN_VALUE(cache->erase(this)); - } - } - - if (!external_stream_ && stream_) { - ORT_IGNORE_RETURN_VALUE(HIP_CALL(hipStreamDestroy(stream_))); - } -} - -ITuningContext* ROCMExecutionProvider::GetTuningContext() const { - return const_cast(&tuning_context_); -} - -std::unique_ptr ROCMExecutionProvider::GetProfiler() { - return std::make_unique(); -} - -ROCMExecutionProvider::PerThreadContext& ROCMExecutionProvider::GetPerThreadContext() const { - const auto& per_thread_context_cache = PerThreadContextCache(); - - // try to use cached context - auto cached_context_it = per_thread_context_cache->find(this); - if (cached_context_it != per_thread_context_cache->end()) { - auto cached_context = cached_context_it->second.lock(); - ORT_ENFORCE(cached_context); - return *cached_context; - } - - // get context and update cache - std::shared_ptr context; - { - std::lock_guard lock(context_state_.mutex); - - // get or create a context - if (context_state_.retired_context_pool.empty()) { - context = std::make_shared(info_.device_id, stream_, info_.gpu_mem_limit, - info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg); - } else { - context = context_state_.retired_context_pool.back(); - context_state_.retired_context_pool.pop_back(); - } - - // insert into active_contexts, should not already be present - const auto active_contexts_insert_result = context_state_.active_contexts.insert(context); - ORT_ENFORCE(active_contexts_insert_result.second); - - // insert into caches_to_update_on_destruction, may already be present - ORT_IGNORE_RETURN_VALUE(context_state_.caches_to_update_on_destruction.insert(per_thread_context_cache)); - } - - per_thread_context_cache->insert(std::make_pair(this, context)); - - return *context; -} - -void ROCMExecutionProvider::ReleasePerThreadContext() const { - const auto& per_thread_context_cache = PerThreadContextCache(); - - auto cached_context_it = per_thread_context_cache->find(this); - ORT_ENFORCE(cached_context_it != per_thread_context_cache->end()); - auto cached_context = cached_context_it->second.lock(); - ORT_ENFORCE(cached_context); - - { - std::lock_guard lock(context_state_.mutex); - context_state_.active_contexts.erase(cached_context); - context_state_.retired_context_pool.push_back(cached_context); - } 
-
-  per_thread_context_cache->erase(cached_context_it);
-}
-
-Status ROCMExecutionProvider::Sync() const {
-  HIP_RETURN_IF_ERROR(hipDeviceSynchronize());
-  return Status::OK();
-}
-
-Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
-  // always set ROCM device when session::Run() in case it runs in a worker thread
-  HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
-  RocmGraphAnnotation_t hip_graph_annotation_id = GetPerThreadContext().GetRocmGraphAnnotationId(run_options);
-  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(hip_graph_annotation_id) &&
-      GetPerThreadContext().IsGraphCaptureAllowed(hip_graph_annotation_id)) {
-    LOGS(*GetLogger(), INFO) << "Capturing the hip graph for this model";
-    GetPerThreadContext().CaptureBegin(hip_graph_annotation_id);
-  }
-  return Status::OK();
-}
-
-Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) {
-  RocmGraphAnnotation_t hip_graph_annotation_id = GetPerThreadContext().GetRocmGraphAnnotationId(run_options);
-  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(hip_graph_annotation_id)) {
-    if (GetPerThreadContext().IsGraphCaptureAllowed(hip_graph_annotation_id)) {
-      GetPerThreadContext().CaptureEnd(hip_graph_annotation_id);
-      // HIP work issued to a capturing stream doesn't actually run on the GPU,
-      // so run the captured graph here to actually execute the work.
-      ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(hip_graph_annotation_id));
-    } else {
-      GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(hip_graph_annotation_id);
-    }
-  }
-
-  if (sync_stream) {
-    HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream_)));
-  }
-
-  // The reason of !IsGraphCaptureEnabled():
-  // If hip graph is enabled, the per thread context will not be released
-  // because the per thread hip graph needs to be maintained and replayed for
-  // the next run.
-  // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
-  // In extreme cases (e.g., 1-op graph and that op falls back to CPU),
-  // PerThreadContext won't be created and there is nothing to
-  // release. This didn't happen before because we always call
-  // GetPerThreadContext in OnRunStart.
-  if (!IsGraphCaptureEnabled() &&
-      PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
-    ReleasePerThreadContext();
-  }
-
-  return Status::OK();
-}
-
-bool ROCMExecutionProvider::IsGraphCaptureEnabled() const {
-  return info_.enable_hip_graph;
-}
-
-bool ROCMExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
-  return GetPerThreadContext().IsGraphCaptured(graph_annotation_id);
-}
-
-Status ROCMExecutionProvider::ReplayGraph(int graph_annotation_id) {
-  return GetPerThreadContext().ReplayGraph(graph_annotation_id);
-}
-
-namespace rocm {
-// opset 1 to 9
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, Cos);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, Cos);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, Cos);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, Sin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, Sin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, Sin);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 4, 10, Concat);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 8, Flatten);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, Squeeze);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, Identity);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, Dropout);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, Gather);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 8, float, MatMul);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 8, double, MatMul);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, MatMul);
-class
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, float, Elu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, double, Elu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, MLFloat16, Elu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, float, HardSigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, double, HardSigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, MLFloat16, HardSigmoid); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 15, float, LeakyRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 15, double, LeakyRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 15, MLFloat16, LeakyRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Relu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Relu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Relu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, float, Selu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, double, Selu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, MLFloat16, Selu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Sigmoid); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Sigmoid); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Sigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Softsign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Softsign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Softsign); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Tanh); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Tanh); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Tanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Softplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Softplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Softplus); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, float, PRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, double, PRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, MLFloat16, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 7, Sum); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 12, Sum); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 11, Max); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Max); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 11, Min); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Min); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 10, bool, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 10, int32_t, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 10, int64_t, Equal); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 12, Expand); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Greater); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint64_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, float, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, double, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, MLFloat16, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint64_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, float, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, double, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 
12, MLFloat16, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint64_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, float, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, double, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, MLFloat16, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint64_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, float, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, double, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, MLFloat16, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int8_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int16_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int32_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int64_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, uint8_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, uint16_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, uint32_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, uint64_t, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Abs); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int8_t, Neg); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int16_t, Neg); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int32_t, Neg); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, int64_t, Neg); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Neg); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Neg); 
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Neg); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Floor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Floor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Floor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Ceil); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Ceil); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Ceil); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 10, float, Clip); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Reciprocal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Reciprocal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Reciprocal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Sqrt); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Sqrt); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Sqrt); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Log); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Log); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Log); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Exp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Exp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Exp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Erf); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Erf); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, bool, Not); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, BatchNormalization); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, LRN); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, double, LRN); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, LRN); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalAveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalAveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, double, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, MLFloat16, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, double, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, MLFloat16, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, 
kOnnxDomain, 1, 12, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, float, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, double, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, int8_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, int16_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, int32_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, int64_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, uint8_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, uint16_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, uint32_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, uint64_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, bool, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int8_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int16_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 
9, 12, int32_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint8_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint16_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, float, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, double, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 4, Reshape); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 5, 12, Reshape); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, Shape); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, Size); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, Tile); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Tile); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, double, InstanceNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, MLFloat16, InstanceNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, float, RNN); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, double, RNN); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, MLFloat16, RNN); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, float, GRU); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, double, GRU); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, MLFloat16, GRU); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, float, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, double, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 13, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, int64_t, Slice); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Compress); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Flatten); 
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, int32_t, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, uint8_t, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 9, float, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 9, double, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 9, MLFloat16, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 9, int32_t, Upsample); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 9, uint8_t, Upsample); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, Split); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, ConstantOfShape); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int8_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int16_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int32_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int64_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint8_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint16_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint32_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint64_t, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Less); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Less); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Less); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, EyeLike); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Scatter); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, MLFloat16, Where); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, float, Where); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, double_t, Where); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, int32_t, Where); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, int64_t, Where); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 15, uint8_t, Where); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, NonZero); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint8_t, NonZero); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, TopK); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, If); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 8, Scan); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Scan); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, Loop); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, DepthToSpace); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, SpaceToDepth); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, RandomNormal); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, RandomNormalLike); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, RandomUniform); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, RandomUniformLike); - -// opset 10 -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, AveragePool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 11, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, MaxPool); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, MaxPool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, ReverseSequence); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, float, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, double, RoiAlign); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, double, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, TopK); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf); - -// opset 11 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Gather); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, GatherElements); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, int64_t, GatherND); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, If); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Loop); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Range); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 15, Scan); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, 
kOnnxDomain, 11, 12, ScatterElements); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Softmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, Softmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Softmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Split); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, TopK); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, SequenceAt); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, SequenceConstruct); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, SequenceEmpty); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, SequenceLength); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, ConcatFromSequence); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, SequenceErase); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, SequenceInsert); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ConvTranspose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ConvTranspose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ConvTranspose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Resize); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, uint8_t, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, Clip); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, bool, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, uint32_t, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, uint64_t, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, Equal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 13, CumSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_float_int64_t, OneHot); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_float_int32_t, OneHot); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_MLFloat16_int64_t, OneHot); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_MLFloat16_int32_t, OneHot); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterND); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 
11, 12, DepthToSpace); - -// OpSet 12 -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, int8_t, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, uint8_t, MaxPool); - -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Pow); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND); - -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin); - -// OpSet 13 -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint64_t, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, float, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, double, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, Add); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Clip); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint64_t, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, float, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, double, Sub); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint64_t, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, float, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, double, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint64_t, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, float, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, double, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Abs); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Neg); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Neg); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Neg); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Neg); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Neg); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Neg); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Neg); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Floor); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Floor); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Ceil); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Ceil); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Ceil); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Reciprocal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Reciprocal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Reciprocal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sqrt); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sqrt); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Log); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Log); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Log); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp); -// Add bf16 support for Exp in opset 13+ for phimm model -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf); -// Add bf16 support for Erf in opset 13+ for phimm model -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Min); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Equal); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Less); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Less); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Less); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Less); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Less); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Less); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Less); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, NonZero); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, NonZero); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, NonZero); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, NonZero); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, float, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, double, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, MLFloat16, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int16_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int32_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int64_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint16_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, bool, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Softmax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Softmax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, LogSoftmax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LogSoftmax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, Split); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Unsqueeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Concat); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Gather); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, GatherElements); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, float, Relu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, double, Relu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, Relu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Tanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Tanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Tanh); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, If); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, Loop); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, LRN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterND); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, 
kOnnxDomain, 13, 13, BFloat16, Sub); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, BFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Mod); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); - -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin); - -// OpSet 14 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Relu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, Relu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, Relu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int32_t, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint32_t, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint64_t, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 
int32_t, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint32_t, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint64_t, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int32_t, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint32_t, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint64_t, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int32_t, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint32_t, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint64_t, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, Div); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 18, Identity); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 18, Reshape); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, RNN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, RNN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, RNN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, GRU); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, GRU); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, GRU); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, LSTM); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, LSTM); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( - kRocmExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( - kRocmExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( - kRocmExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); 
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, Trilu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Relu); - -// OpSet 15 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Pow); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, float, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, double, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, 18, Shape); - -// Opset 16 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LeakyRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LeakyRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LeakyRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 18, Scan); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, BFloat16, Where); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, Where); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double_t, Where); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, int32_t, Where); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, int64_t, Where); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, uint8_t, Where); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, int32_t, GreaterOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, 
kOnnxDomain, 16, int64_t, GreaterOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, uint32_t, GreaterOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, uint64_t, GreaterOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, GreaterOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, GreaterOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, GreaterOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, int32_t, LessOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, int64_t, LessOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, uint32_t, LessOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, uint64_t, LessOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, GridSample); - -// Opset 17 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, double, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, BFloat16, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization); - -// Opset 18 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); - -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); - -// Opset 19 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, float, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, double, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, MLFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, BFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, int8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, int16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, int32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, int64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, bool, Cast); -// #if !defined(DISABLE_FLOAT8_TYPES) -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast); -// #endif - -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, DequantizeLinear); -// #if !defined(DISABLE_FLOAT8_TYPES) -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, DequantizeLinear); -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, DequantizeLinear); -// #endif -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, DequantizeLinear); -// #if !defined(DISABLE_FLOAT8_TYPES) -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, DequantizeLinear); -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, DequantizeLinear); -// #endif - -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Identity); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, If); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Loop); -class 
ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, QuantizeLinear); -// #if !defined(DISABLE_FLOAT8_TYPES) -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, QuantizeLinear); -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, QuantizeLinear); -// #endif -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, QuantizeLinear); -// #if !defined(DISABLE_FLOAT8_TYPES) -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, QuantizeLinear); -// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear); -// #endif -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape); - -// Opset 20 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, float, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, double, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN); - -// Opset 21. 
-// TODO(fajin): support other quantized types -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, float, DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, DequantizeLinear); -// #if !defined(DISABLE_FLOAT8_TYPES) -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, DequantizeLinear); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, DequantizeLinear); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, DequantizeLinear); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, DequantizeLinear); -// #endif - -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, float, QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, QuantizeLinear); -// #if !defined(DISABLE_FLOAT8_TYPES) -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, QuantizeLinear); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, QuantizeLinear); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear); -// #endif - -template <> -KernelCreateInfo BuildKernelCreateInfo<void>() { - return {}; -} - -// clang-format off -static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 10 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 12 - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 13 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, 
- BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 14 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 15 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 16 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 17 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 18 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 19 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -//#if !defined(DISABLE_FLOAT8_TYPES) -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -//#endif - - BuildKernelCreateInfo, - BuildKernelCreateInfo, -//#if !defined(DISABLE_FLOAT8_TYPES) -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -//#endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, -//#if !defined(DISABLE_FLOAT8_TYPES) -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -//#endif - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, -//#if !defined(DISABLE_FLOAT8_TYPES) -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -//#endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, -//#if !defined(DISABLE_FLOAT8_TYPES) -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -//#endif - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 20 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 21 - // TODO(fajin): support other quantized types - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -//#if !defined(DISABLE_FLOAT8_TYPES) -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -//#endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -//#if !defined(DISABLE_FLOAT8_TYPES) -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -// BuildKernelCreateInfo, -//#endif - }; - - for (auto& function_table_entry : function_table) { - KernelCreateInfo info = function_table_entry(); - if (info.kernel_def != nullptr) { // filter disabled entries where type is void - ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); - } - } - -#ifndef DISABLE_CONTRIB_OPS - ORT_RETURN_IF_ERROR(::onnxruntime::contrib::rocm::RegisterRocmContribKernels(kernel_registry)); -#endif - -#ifdef ENABLE_TRAINING_OPS - ORT_RETURN_IF_ERROR(::onnxruntime::rocm::RegisterRocmTrainingKernels(kernel_registry)); -#endif - - return Status::OK(); -} -// clang-format on - -} // namespace rocm - -static std::shared_ptr s_kernel_registry; - -void InitializeRegistry() { - s_kernel_registry = KernelRegistry::Create(); - ORT_THROW_IF_ERROR(rocm::RegisterRocmKernels(*s_kernel_registry)); -} - -void DeleteRegistry() { - s_kernel_registry.reset(); -} - -std::shared_ptr ROCMExecutionProvider::GetKernelRegistry() const { - return s_kernel_registry; -} - -static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { - const auto& node_attributes = node.GetAttributes(); - // Check attributes - for (auto& attr : node_attributes) { - auto& attr_name = attr.first; - auto& attr_value = attr.second; - - // string is not supported - if ("to" == attr_name && ::ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT == attr_value.type()) { - auto to_type = attr_value.i(); - if (to_type == ::ONNX_NAMESPACE::TensorProto_DataType_STRING) - return true; - } - } - - return false; -} - -static bool ArgMaxOrArgMinNeedFallbackToCPU(const 
onnxruntime::Node& node) { - // Opset 12 introduced the attribute "select_last_index" - if (node.SinceVersion() >= 12) { - const auto& node_attributes = node.GetAttributes(); - - for (auto& attr : node_attributes) { - auto& attr_name = attr.first; - auto& attr_value = attr.second; - - // It is not supported to pick the last index in case of encountering duplicate max values. - if ("select_last_index" == attr_name) { - if (attr_value.i() != 0) { - return true; - } - } - } - } - - return false; -} -std::unique_ptr ROCMExecutionProvider::GetDataTransfer() const { - return std::make_unique(); -} - -std::vector> -ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, - IResourceAccountant* /* resource_accountant */) const { - InlinedVector candidates; - // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. - InlinedVector tentative_nodes; - const logging::Logger& logger = *GetLogger(); - for (auto& node_index : graph.GetNodesInTopologicalOrder()) { - const auto* p_node = graph.GetNode(node_index); - if (p_node == nullptr) - continue; - - const auto& node = *p_node; - if (!node.GetExecutionProviderType().empty()) { - if (node.GetExecutionProviderType() == kRocmExecutionProvider) { - candidates.push_back(node.Index()); - } - continue; - } - - const KernelCreateInfo* rocm_kernel_def = kernel_lookup.LookUpKernel(node); - // none of the provided registries has a ROCM kernel for this node - if (rocm_kernel_def == nullptr) { - LOGS(logger, INFO) << "ROCM kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); - continue; - } - - bool not_supported = false; - bool force_inside = false; // for some compute heavy ops, we'll force it to run inside ROCM - if ("LSTM" == node.OpType() || - "RNN" == node.OpType() || - "GRU" == node.OpType()) { - not_supported = true; - force_inside = !not_supported; - } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) { - not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node); - force_inside = !not_supported; - } else if ("Cast" == node.OpType()) { - not_supported = CastNeedFallbackToCPU(node); - // cast is not compute heavy, and may be placed outside - } - - if (!force_inside && not_supported) { - if (not_supported) { - LOGS(logger, WARNING) << "ROCM kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); - } - } else { - tentative_nodes.push_back(node.Index()); - candidates.push_back(node.Index()); - } - } - - // For ROCM EP, exclude the subgraph that is preferred to be placed in CPU - // These are usually shape related computation subgraphs - // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); - std::vector> result; - for (auto& node_index : candidates) { - if (cpu_nodes.count(node_index) > 0) - continue; - - auto sub_graph = IndexedSubGraph::Create(); - sub_graph->Nodes().push_back(node_index); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); - } - return result; -} - -void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const { - // This allocator must be the same to the allocator - // used in AllocateBufferOnCPUPinned. 
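The CPU-fallback checks in GetCapability above reduce to a small predicate over op type and attributes. As a rough illustration (a standalone sketch, not the ORT graph API — `NodeInfo` and its attribute map below are stand-ins for `onnxruntime::Node`, which exposes `OpType()`, `SinceVersion()` and `GetAttributes()` in the real code):

```cpp
// Standalone sketch of the CPU-fallback predicate used in GetCapability above.
// NodeInfo is a simplified stand-in for onnxruntime::Node.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

struct NodeInfo {
  std::string op_type;
  int since_version = 0;                               // opset the node was resolved against
  std::unordered_map<std::string, int64_t> int_attrs;  // only int attributes matter here
};

constexpr int64_t kTensorProtoString = 8;  // ONNX TensorProto_DataType_STRING

bool NeedsCpuFallback(const NodeInfo& node) {
  if (node.op_type == "Cast") {
    // Casting to string is not supported by the GPU kernel.
    auto it = node.int_attrs.find("to");
    return it != node.int_attrs.end() && it->second == kTensorProtoString;
  }
  if (node.op_type == "ArgMax" || node.op_type == "ArgMin") {
    // Opset 12 added select_last_index; picking the last duplicate max is unsupported.
    if (node.since_version >= 12) {
      auto it = node.int_attrs.find("select_last_index");
      return it != node.int_attrs.end() && it->second != 0;
    }
  }
  return false;
}

int main() {
  NodeInfo cast{"Cast", 13, {{"to", kTensorProtoString}}};
  NodeInfo argmax{"ArgMax", 12, {{"select_last_index", 1}}};
  NodeInfo add{"Add", 14, {}};
  std::cout << NeedsCpuFallback(cast) << NeedsCpuFallback(argmax) << NeedsCpuFallback(add) << "\n";  // prints 110
  return 0;
}
```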
- auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)]; - RegisterRocmStreamHandles(stream_handle_registry, - OrtDevice::GPU, - allocator, - !IsGraphCaptureEnabled(), - stream_, - use_ep_level_unified_stream_, - GetPerThreadContext().MiopenHandle(), - GetPerThreadContext().HipblasHandle(), - info_); -} - -OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput) - return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, - 0 /*CPU device id always be 0*/); - return default_device_; -} - -std::vector ROCMExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo pinned_memory_info( - [](OrtDevice::DeviceId) { - return std::make_unique(HIP_PINNED); - }, - // TODO: should we use info_.device_id instead of DEFAULT_CPU_ALLOCATOR_DEVICE_ID? - // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb - // says the pinned memory allocated by cudaMallocHost is associated with a specific device, so it may be more - // correct to use the GPU device id, unless we wanted to share the pinned memory allocator across devices, - // at the risk the lifetime isn't managed correctly if one of those devices go away. - 0); - return std::vector{ - CreateRocmAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy, - info_.external_allocator_info, info_.default_memory_arena_cfg), - CreateAllocator(pinned_memory_info), - }; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h deleted file mode 100644 index 2baaf2ff1a886..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/framework/arena_extend_strategy.h" -#include "core/framework/execution_provider.h" -#include -#include "core/providers/rocm/rocm_execution_provider_info.h" -#include "core/providers/rocm/rocm_graph.h" -#include "core/providers/rocm/rocm_pch.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/shared_inc/rocm_call.h" -#include "core/providers/rocm/tunable/rocm_tuning_context.h" - -namespace onnxruntime { - -void RunOnUnload(std::function function); - -// Logical device representation. -class ROCMExecutionProvider : public IExecutionProvider { - public: - explicit ROCMExecutionProvider(const ROCMExecutionProviderInfo& info); - virtual ~ROCMExecutionProvider(); - - Status Sync() const override; - - Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - - Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; - - const void* GetExecutionHandle() const noexcept override { - // The ROCM interface does not return anything interesting. 
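For reference, the provider interface declared here is what `OrtROCMProviderOptions` ultimately fed into. A minimal consumer sketch, assuming the standard ONNX Runtime C++ bindings (`Ort::SessionOptions::AppendExecutionProvider_ROCM`), a `--use_rocm` build, and a hypothetical model path:

```cpp
// Minimal sketch: creating a session with the ROCm EP through the public API.
// Assumes an ONNX Runtime build with ROCm support; "model.onnx" is a placeholder path.
#include <onnxruntime_cxx_api.h>
#include <cstdint>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "rocm-example");
  Ort::SessionOptions session_options;

  OrtROCMProviderOptions rocm_options{};    // defaults, then override selected fields
  rocm_options.device_id = 0;
  rocm_options.gpu_mem_limit = SIZE_MAX;    // no explicit arena cap
  rocm_options.do_copy_in_default_stream = 1;
  rocm_options.tunable_op_enable = 1;       // opt into TunableOp
  rocm_options.tunable_op_tuning_enable = 1;

  session_options.AppendExecutionProvider_ROCM(rocm_options);

  // char* path form shown here (Linux); Windows builds take a wide-character path.
  Ort::Session session(env, "model.onnx", session_options);
  return 0;
}
```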
- return nullptr; - } - - hipblasHandle_t PerThreadDefaultHipblasHandle() { - return GetPerThreadContext().HipblasHandle(); - } - - miopenHandle_t PerThreadDefaultMiopenHandle() { - return GetPerThreadContext().MiopenHandle(); - } - - hipStream_t ComputeStream() { - // this will return the ROCM EP level stream which can differ from the actual compute tasks stream - // the compute task stream is supplied within OpKernelContext during inference - return stream_; - } - - template - const T* GetConstOnes(size_t count, hipStream_t stream) { - return GetPerThreadContext().template GetConstOnes(count, stream); - } - - std::shared_ptr GetKernelRegistry() const override; - std::unique_ptr GetDataTransfer() const override; - - std::vector> GetCapability( - const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, - IResourceAccountant* /* resource_accountant */) const override; - - int GetDeviceId() const override { return info_.device_id; } - const hipDeviceProp_t& GetDeviceProp() const { return device_prop_; }; - int GetMiopenConvExhaustiveSearch() const { return info_.miopen_conv_exhaustive_search; } - bool DoCopyOnDefaultStream() const { return info_.do_copy_in_default_stream; } - bool GetMiopenConvUseMaxWorkspace() const { return info_.miopen_conv_use_max_workspace; } - - ProviderOptions GetProviderOptions() const override { - return ROCMExecutionProviderInfo::ToProviderOptions(info_); - } - - static AllocatorPtr CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); - - ITuningContext* GetTuningContext() const override; - - std::unique_ptr GetProfiler() override; - - bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured(RocmGraphAnnotation_t graph_annotation_id) const override; - Status ReplayGraph(RocmGraphAnnotation_t graph_annotation_id) override; - void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; - std::vector CreatePreferredAllocators() override; - - private: - ROCMExecutionProviderInfo info_; - hipDeviceProp_t device_prop_; - bool external_stream_ = false; - // only used when set user external stream or hip graph - hipStream_t stream_ = nullptr; - - bool use_ep_level_unified_stream_ = false; - - // the tuning context might be altered when calling into a TunableOp - mutable rocm::tunable::RocmTuningContext tuning_context_; - - class PerThreadContext final { - public: - PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); - ~PerThreadContext(); - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); - - hipblasHandle_t HipblasHandle() const { - return hipblas_handle_; - } - - miopenHandle_t MiopenHandle() const { - return miopen_handle_; - } - - template - const T* GetConstOnes(size_t count, hipStream_t stream) { - constexpr bool is_float = std::is_same::value; - constexpr bool is_double = std::is_same::value; - constexpr bool is_half = std::is_same::value; - constexpr bool is_BFloat16 = std::is_same::value; - if (is_float) { - if (!constant_ones_float_) { - constant_ones_float_ = rocm::CreateConstantOnes(); - } - return 
reinterpret_cast(constant_ones_float_->GetBuffer(stream, count)); - } else if (is_double) { - if (!constant_ones_double_) { - constant_ones_double_ = rocm::CreateConstantOnes(); - } - return reinterpret_cast(constant_ones_double_->GetBuffer(stream, count)); - } else if (is_half) { - if (!constant_ones_half_) { - constant_ones_half_ = rocm::CreateConstantOnes(); - } - return reinterpret_cast(constant_ones_half_->GetBuffer(stream, count)); - } else if (is_BFloat16) { - if (!constant_ones_bfloat16_) { - constant_ones_bfloat16_ = rocm::CreateConstantOnes(); - } - return reinterpret_cast(constant_ones_bfloat16_->GetBuffer(stream, count)); - } else { - return nullptr; - } - } - - bool IsGraphCaptureAllowed(RocmGraphAnnotation_t hip_graph_annotation_id) const; - bool IsGraphCaptureAllowedOnRun(RocmGraphAnnotation_t hip_graph_annotation_id) const; - void CaptureBegin(RocmGraphAnnotation_t hip_graph_annotation_id); - void CaptureEnd(RocmGraphAnnotation_t hip_graph_annotation_id); - bool IsGraphCaptured(RocmGraphAnnotation_t hip_graph_annotation_id) const; - RocmGraphAnnotation_t GetRocmGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; - Status ReplayGraph(RocmGraphAnnotation_t hip_graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(RocmGraphAnnotation_t hip_graph_annotation_id); - - private: - hipblasHandle_t hipblas_handle_ = nullptr; - miopenHandle_t miopen_handle_ = nullptr; - - std::unique_ptr> constant_ones_float_; - std::unique_ptr> constant_ones_double_; - std::unique_ptr> constant_ones_half_; - std::unique_ptr> constant_ones_bfloat16_; - - // Hip graph with multi threads will be supported in the future, so hip_graph_ - // is put under PerThreadContext. - ROCMGraph hip_graph_; - // Map of graph id to regular_run_count_before_graph_capture - std::unordered_map graph_id_to_run_count_; - - // There is chance that the second regular run allocates GPU memory for causes like: - // (1) memory pattern is enabled. (2) arena allocation for stream. - // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs - // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. - }; - - using PerThreadContextMap = std::unordered_map>; - // thread local PerThreadContext cache - - struct ContextCacheHolder { - ContextCacheHolder() { - // Keep a weak pointer to the object, if the weak pointer can be locked, then the shared pointer is still around, so we can reset it - RunOnUnload([&, weak_p_ = std::weak_ptr(p)] { - if (auto lock = weak_p_.lock()) - p.reset(); - }); - } - std::shared_ptr p = std::make_shared(); - }; - - static const std::shared_ptr& PerThreadContextCache() { - thread_local const ContextCacheHolder per_thread_context_cache; - return per_thread_context_cache.p; - } - - struct PerThreadContextState { - // contexts that are currently active - std::set, std::owner_less>> active_contexts; - // contexts available for reuse - std::vector> retired_context_pool; - // weak references to thread local caches from which this ROCMExecutionProvider instance's entry should be removed - // upon destruction - std::set, std::owner_less>> - caches_to_update_on_destruction; - // synchronizes access to PerThreadContextState members - std::mutex mutex; - }; - - // The execution provider maintains the PerThreadContexts in this structure. 
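The PerThreadContext machinery above is, at its core, a thread-local cache of per-provider state plus bookkeeping so the provider can tear the cache down later. A stripped-down sketch of that pattern in plain C++ (no HIP/MIOpen handles, no retirement pool, no unload hook):

```cpp
// Simplified per-thread context cache: each thread lazily creates one Context per
// Provider instance and reuses it. The real code adds a retirement pool, weak_ptr
// cleanup on library unload, and GPU library handles inside the context.
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>

struct Context {
  int id;
  explicit Context(int i) : id(i) { std::cout << "create ctx " << i << "\n"; }
};

class Provider {
 public:
  explicit Provider(int id) : id_(id) {}

  Context& GetPerThreadContext() const {
    // One cache per thread; keyed by provider instance so several providers can coexist.
    thread_local std::unordered_map<const Provider*, std::shared_ptr<Context>> cache;
    auto& slot = cache[this];
    if (!slot) {
      std::lock_guard<std::mutex> lock(mutex_);  // guards shared creation bookkeeping
      slot = std::make_shared<Context>(id_);
    }
    return *slot;
  }

 private:
  int id_;
  mutable std::mutex mutex_;
};

int main() {
  Provider p(42);
  auto worker = [&p] { p.GetPerThreadContext(); p.GetPerThreadContext(); };  // second call reuses
  std::thread t1(worker), t2(worker);
  t1.join();
  t2.join();
  return 0;  // expect exactly one "create ctx 42" line per thread
}
```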
- // Synchronization is required to update the contained structures. - // On the other hand, access to an individual PerThreadContext is assumed to be from a single thread at a time, - // so synchronization is not required for that. - mutable PerThreadContextState context_state_; - - PerThreadContext& GetPerThreadContext() const; - void ReleasePerThreadContext() const; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc deleted file mode 100644 index 3cb826437a54f..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_execution_provider_info.h" - -#include "core/common/make_string.h" -#include "core/common/parse_string.h" -#include "core/framework/provider_options_utils.h" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace provider_option_names { -constexpr const char* kDeviceId = "device_id"; -constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; -constexpr const char* kUserComputeStream = "user_compute_stream"; -constexpr const char* kMemLimit = "gpu_mem_limit"; -constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; -constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search"; -constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream"; -constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; -constexpr const char* kGpuExternalFree = "gpu_external_free"; -constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; -constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace"; -constexpr const char* kEnableHipGraph = "enable_hip_graph"; -constexpr const char* kTunableOpEnable = "tunable_op_enable"; -constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; -constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; -} // namespace provider_option_names -} // namespace rocm - -const EnumNameMapping arena_extend_strategy_mapping{ - {ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"}, - {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, -}; - -ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { - ROCMExecutionProviderInfo info{}; - void* alloc = nullptr; - void* free = nullptr; - void* empty_cache = nullptr; - void* user_compute_stream = nullptr; - ORT_THROW_IF_ERROR( - ProviderOptionsParser{} - .AddValueParser( - rocm::provider_option_names::kDeviceId, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); - int num_devices{}; - HIP_RETURN_IF_ERROR(hipGetDeviceCount(&num_devices)); - ORT_RETURN_IF_NOT( - 0 <= info.device_id && info.device_id < num_devices, - "Invalid device ID: ", info.device_id, - ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); - return Status::OK(); - }) - .AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) - .AddValueParser( - rocm::provider_option_names::kUserComputeStream, - [&user_compute_stream](const std::string& value_str) -> Status { - size_t address; - 
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - user_compute_stream = reinterpret_cast(address); - return Status::OK(); - }) - .AddValueParser( - rocm::provider_option_names::kGpuExternalAlloc, - [&alloc](const std::string& value_str) -> Status { - size_t address; - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - alloc = reinterpret_cast(address); - return Status::OK(); - }) - .AddValueParser( - rocm::provider_option_names::kGpuExternalFree, - [&free](const std::string& value_str) -> Status { - size_t address; - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - free = reinterpret_cast(address); - return Status::OK(); - }) - .AddValueParser( - rocm::provider_option_names::kGpuExternalEmptyCache, - [&empty_cache](const std::string& value_str) -> Status { - size_t address; - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - empty_cache = reinterpret_cast(address); - return Status::OK(); - }) - .AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.gpu_mem_limit) - .AddAssignmentToEnumReference( - rocm::provider_option_names::kArenaExtendStrategy, - arena_extend_strategy_mapping, info.arena_extend_strategy) - .AddAssignmentToReference( - rocm::provider_option_names::kMiopenConvExhaustiveSearch, - info.miopen_conv_exhaustive_search) - .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) - .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) - .AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph) - .AddValueParser( - rocm::provider_option_names::kTunableOpEnable, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enable)); - return Status::OK(); - }) - .AddValueParser( - rocm::provider_option_names::kTunableOpTuningEnable, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.tuning_enable)); - return Status::OK(); - }) - .AddValueParser( - rocm::provider_option_names::kTunableOpMaxTuningDurationMs, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.max_tuning_duration_ms)); - return Status::OK(); - }) - .Parse(options)); - - ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; - info.external_allocator_info = alloc_info; - - info.user_compute_stream = user_compute_stream; - info.has_user_compute_stream = (user_compute_stream != nullptr); - - return info; -} - -ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) { - const ProviderOptions options{ - {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, - {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, - {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, - {rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, - {rocm::provider_option_names::kGpuExternalFree, 
MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, - {rocm::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, - {rocm::provider_option_names::kArenaExtendStrategy, - EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, - {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, - {rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)}, - {rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)}, - {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)}, - {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, - {rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, - }; - - return options; -} - -ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) { - const ProviderOptions options{ - {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, - {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, - {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, - {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.arena_extend_strategy))}, - {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, - {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, - {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, - {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, - {rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, - }; - - return options; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h deleted file mode 100644 index c245b18057ca7..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/common/hash_combine.h" -#include "core/framework/arena_extend_strategy.h" -#include "core/framework/ortdevice.h" -#include "core/framework/provider_options.h" -#include "core/session/onnxruntime_c_api.h" - -namespace onnxruntime { -// Information needed to construct ROCM execution providers. 
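Both directions of this conversion are string round-trips over a `ProviderOptions` map (a string-to-string map in ORT); even raw pointers travel as decimal addresses. A self-contained sketch of that round-trip, with a cut-down struct standing in for `ROCMExecutionProviderInfo`:

```cpp
// Round-tripping a provider-options struct through a string map, the way
// FromProviderOptions/ToProviderOptions above do. Opts is a cut-down stand-in
// for ROCMExecutionProviderInfo; pointers are carried as decimal addresses.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct Opts {
  int device_id = 0;
  size_t gpu_mem_limit = SIZE_MAX;
  bool do_copy_in_default_stream = true;
  void* user_compute_stream = nullptr;
};

std::map<std::string, std::string> ToOptions(const Opts& o) {
  return {
      {"device_id", std::to_string(o.device_id)},
      {"gpu_mem_limit", std::to_string(o.gpu_mem_limit)},
      {"do_copy_in_default_stream", o.do_copy_in_default_stream ? "1" : "0"},
      {"user_compute_stream", std::to_string(reinterpret_cast<uintptr_t>(o.user_compute_stream))},
  };
}

Opts FromOptions(const std::map<std::string, std::string>& m) {
  Opts o;
  if (auto it = m.find("device_id"); it != m.end()) o.device_id = std::stoi(it->second);
  if (auto it = m.find("gpu_mem_limit"); it != m.end()) o.gpu_mem_limit = std::stoull(it->second);
  if (auto it = m.find("do_copy_in_default_stream"); it != m.end()) o.do_copy_in_default_stream = it->second != "0";
  if (auto it = m.find("user_compute_stream"); it != m.end())
    o.user_compute_stream = reinterpret_cast<void*>(static_cast<uintptr_t>(std::stoull(it->second)));
  return o;
}

int main() {
  int dummy_stream = 0;  // stands in for a hipStream_t handle
  Opts in;
  in.device_id = 1;
  in.user_compute_stream = &dummy_stream;
  Opts out = FromOptions(ToOptions(in));
  std::cout << out.device_id << " " << (out.user_compute_stream == &dummy_stream) << "\n";  // 1 1
  return 0;
}
```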
-struct ROCMExecutionProviderExternalAllocatorInfo { - void* alloc{nullptr}; - void* free{nullptr}; - void* empty_cache{nullptr}; - - ROCMExecutionProviderExternalAllocatorInfo() { - alloc = nullptr; - free = nullptr; - empty_cache = nullptr; - } - - ROCMExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { - alloc = a; - free = f; - empty_cache = e; - } - - bool UseExternalAllocator() const { - return (alloc != nullptr) && (free != nullptr); - } -}; - -namespace rocm { -struct TunableOpInfo { - bool enable{false}; - bool tuning_enable{false}; - int max_tuning_duration_ms{}; -}; -} // namespace rocm - -struct ROCMExecutionProviderInfo { - OrtDevice::DeviceId device_id{0}; - size_t gpu_mem_limit{std::numeric_limits::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) - ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) - bool miopen_conv_exhaustive_search{false}; - bool do_copy_in_default_stream{true}; - bool has_user_compute_stream{false}; - void* user_compute_stream{nullptr}; - // The following OrtArenaCfg instance only characterizes the behavior of the default memory - // arena allocator and not any other auxiliary allocator that may also be part of the ROCM EP. - // For example, auxiliary allocators `HIP_PINNED` and `HIP_CPU` will not be configured using this - // arena config. - OrtArenaCfg* default_memory_arena_cfg{nullptr}; - ROCMExecutionProviderExternalAllocatorInfo external_allocator_info{}; - - // By default, try to use as much as possible memory for algo search. - // If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best. - bool miopen_conv_use_max_workspace{true}; - - bool enable_hip_graph{false}; - - rocm::TunableOpInfo tunable_op{}; - - static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); - static ProviderOptions ToProviderOptions(const ROCMExecutionProviderInfo& info); - static ProviderOptions ToProviderOptions(const OrtROCMProviderOptions& info); -}; -} // namespace onnxruntime - -template <> -struct std::hash<::onnxruntime::ROCMExecutionProviderInfo> { - size_t operator()(const ::onnxruntime::ROCMExecutionProviderInfo& info) const { - size_t value{0xbc9f1d34}; // seed - - // Bits: device_id (16), arena_extend_strategy/miopen_conv_exhaustive_search (reserved 2), boolean options (1 each) - size_t data = static_cast(info.device_id) ^ - (static_cast(info.arena_extend_strategy) << 16) ^ - (static_cast(info.miopen_conv_exhaustive_search) << 18) ^ - (static_cast(info.do_copy_in_default_stream) << 20) ^ - (static_cast(info.has_user_compute_stream) << 21) ^ - (static_cast(info.miopen_conv_use_max_workspace) << 22) ^ - (static_cast(info.enable_hip_graph) << 23) ^ - (static_cast(info.tunable_op.enable) << 24) ^ - (static_cast(info.tunable_op.tuning_enable) << 25); - onnxruntime::HashCombine(data, value); - - onnxruntime::HashCombine(info.gpu_mem_limit, value); - onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); - - // Memory pointers - onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); - - // The default memory arena cfg is not 
used in hashing right now. - return value; - } -}; diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h deleted file mode 100644 index 933a72122e7f9..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/backward_guard.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/rocm_execution_provider.h" -#include "core/providers/rocm/rocm_fwd.h" -#include "core/providers/rocm/rocm_stream_handle.h" - -namespace onnxruntime { -namespace rocm { - -// ----------------------------------------------------------------------- -// Base class for HIP kernels -// ----------------------------------------------------------------------- -class RocmKernel : public OpKernel { - public: - explicit RocmKernel(const OpKernelInfo& info) - : OpKernel(info), - // Is this OK to have a non-const execution provider? - provider_(const_cast(static_cast(info.GetExecutionProvider()))) { - } - - Status Compute(OpKernelContext* p_op_kernel_context) const override { - Status s; - auto is_backward_pass = Info().GetAttrOrDefault("__backwardpass", 0); - if (is_backward_pass) { - BackwardPassGuard guard; - s = ComputeInternal(p_op_kernel_context); - } else { - s = ComputeInternal(p_op_kernel_context); - } - // use this to precisely locate the node where ROCM failure comes from - // if (hipSuccess != hipDeviceSynchronize()) - // __debugbreak(); - if (s.IsOK()) { - auto err = hipGetLastError(); - if (err != hipSuccess) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP error ", hipGetErrorName(err), ":", hipGetErrorString(err)); - } - } - return s; - } - - virtual Status ComputeInternal(OpKernelContext* p_op_kernel_context) const = 0; - - template - inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, onnxruntime::Stream* stream) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, false, stream, WaitRocmNotificationOnDevice); - } - - // Different from GetScratchBuffer which use IAllocator::Alloc() to allocate memory, - // this GetTransientScratchBuffer will call IAllocator::Reserve() to allocate memory. - // IAllocator::Reserve() optionally implement some allocation logic that by-passes any arena-based - // logic (or similar for different allocator) that may be housed in the Alloc() implementation. - template - inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t count_or_bytes) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, true); - } - - inline void AddDeferredReleaseCPUPtr(void* p, onnxruntime::Stream* ort_stream) const { - ORT_ENFORCE(ort_stream->GetDevice().Type() == OrtDevice::GPU); - auto* rocm_ep_stream = static_cast(ort_stream); - rocm_ep_stream->EnqueDeferredCPUBuffer(p); - } - - template - inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t count_or_bytes) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeCPU), count_or_bytes); - } - - const hipDeviceProp_t& GetDeviceProp() const { return provider_->GetDeviceProp(); } - - inline hipStream_t Stream(OpKernelContext* ctx) const { - auto* stream = ctx->GetComputeStream(); - return stream ? 
static_cast(stream->GetHandle()) : nullptr; - } - - inline miopenHandle_t GetMiopenHandle(OpKernelContext* ctx) const { - return GetMiopenHandle(static_cast(ctx->GetComputeStream())); - } - - static inline miopenHandle_t GetMiopenHandle(onnxruntime::RocmStream* stream) { - return stream->miopen_handle_; - } - - inline hipblasHandle_t GetHipblasHandle(OpKernelContext* ctx) const { - return GetHipblasHandle(static_cast(ctx->GetComputeStream())); - } - - static inline hipblasHandle_t GetHipblasHandle(onnxruntime::RocmStream* stream) { - return stream->hipblas_handle_; - } - - bool UseTF32() const { - return false; - } - - tunable::RocmTuningContext* GetTuningContext() const { - return static_cast(provider_->GetTuningContext()); - } - - // To support hipMemcpyAsync, the cpu memory should be allocated in pinned memory - // and it can only be released after the copy has finished - template - class RocmAsyncBuffer { - public: - RocmAsyncBuffer(const RocmKernel* op_kernel) : gpu_copy_(nullptr), count_(0), op_kernel_(op_kernel) {} - - RocmAsyncBuffer(const RocmKernel* op_kernel, size_t count) : RocmAsyncBuffer(op_kernel) { - AllocCpuPtr(count); - } - - RocmAsyncBuffer(const RocmKernel* op_kernel, const T& value, size_t count) - : RocmAsyncBuffer(op_kernel, count) { - T* p = CpuPtr(); - for (size_t i = 0; i != count; ++i) { - *p++ = value; - } - } - - RocmAsyncBuffer(const RocmKernel* op_kernel, gsl::span vec) : RocmAsyncBuffer(op_kernel, vec.size()) { - memcpy(CpuPtr(), vec.data(), vec.size() * sizeof(T)); - } - - void AllocCpuPtr(size_t count) { - cpu_pinned_copy_ = op_kernel_->AllocateBufferOnCPUPinned(count); - if (cpu_pinned_copy_ == nullptr) - throw std::runtime_error("alloc failed"); - count_ = count; - } - - Status CopyToGpu(onnxruntime::Stream* stream) { - if (cpu_pinned_copy_) { - gpu_copy_ = op_kernel_->GetScratchBuffer(count_, stream); - hipStream_t rocm_stream = stream ? 
static_cast(stream->GetHandle()) : nullptr; - HIP_RETURN_IF_ERROR(hipMemcpyAsync(gpu_copy_.get(), cpu_pinned_copy_.get(), count_ * sizeof(T), hipMemcpyHostToDevice, - rocm_stream)); - op_kernel_->AddDeferredReleaseCPUPtr(cpu_pinned_copy_.release(), stream); - } - return Status::OK(); - } - - T* CpuPtr() const { - return cpu_pinned_copy_.get(); - } - - gsl::span CpuSpan() const { - return gsl::span(CpuPtr(), count_); - } - - T* GpuPtr() const { - return gpu_copy_.get(); - } - - size_t count() const { - return count_; - } - - protected: - IAllocatorUniquePtr gpu_copy_; - IAllocatorUniquePtr cpu_pinned_copy_; - size_t count_; - const RocmKernel* op_kernel_; - }; - - inline hipblasHandle_t DefaultHipblasHandle() const { - return provider_->PerThreadDefaultHipblasHandle(); - } - - inline miopenHandle_t DefaultMiopenHandle() const { - return provider_->PerThreadDefaultMiopenHandle(); - } - - inline hipStream_t DefaultHipStream() const { - // this will return the ROCM EP level stream which can differ from the actual compute tasks stream - // the compute task stream is supplied within OpKernelContext during inference - return provider_->ComputeStream(); - } - - inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); - return gpu_data_transfer->CopyTensorAsync(src, dst, stream); - } - - protected: - template - inline const T* GetConstOnes(size_t count, hipStream_t stream) const { - return provider_->template GetConstOnes(count, stream); - } - - inline int GetDeviceId() const { return provider_->GetDeviceId(); } - - private: - ROCMExecutionProvider* provider_; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_pch.h b/onnxruntime/core/providers/rocm/rocm_pch.h deleted file mode 100644 index 9713e41e126bb..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_pch.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#if defined(_MSC_VER) -#pragma warning(push) -// hip_fp16.hpp(394,38): warning C4505: '__float2half_rz': unreferenced local function has been removed -#pragma warning(disable : 4505) -#endif - -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef ORT_USE_NCCL -#include -#endif - -#ifdef USE_HIPBLASLT -#include -#endif - -#if defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/onnxruntime/core/providers/rocm/rocm_profiler.cc b/onnxruntime/core/providers/rocm/rocm_profiler.cc deleted file mode 100644 index de52f512c5229..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_profiler.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
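RocmAsyncBuffer wraps the standard HIP pattern: stage data in pinned host memory, issue hipMemcpyAsync on the compute stream, and only release the pinned buffer once the stream has caught up (ORT defers that release via the stream; the sketch below simply synchronizes). A minimal standalone version, assuming a ROCm toolchain:

```cpp
// Minimal pinned-memory async host-to-device copy, the pattern RocmAsyncBuffer wraps.
// Requires a ROCm install; build with: hipcc async_copy.cpp -o async_copy
#include <hip/hip_runtime.h>
#include <cstdio>

#define HIP_CHECK(expr)                                                  \
  do {                                                                   \
    hipError_t err = (expr);                                             \
    if (err != hipSuccess) {                                             \
      std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err));   \
      return 1;                                                          \
    }                                                                    \
  } while (0)

int main() {
  constexpr size_t n = 1024;

  float* pinned = nullptr;  // host staging buffer, pinned for async DMA
  HIP_CHECK(hipHostMalloc(reinterpret_cast<void**>(&pinned), n * sizeof(float), hipHostMallocDefault));
  for (size_t i = 0; i < n; ++i) pinned[i] = 1.0f;

  float* device = nullptr;
  HIP_CHECK(hipMalloc(reinterpret_cast<void**>(&device), n * sizeof(float)));

  hipStream_t stream;
  HIP_CHECK(hipStreamCreate(&stream));

  // The async copy returns immediately; the pinned buffer must stay alive until the
  // stream reaches this point (ORT enqueues a deferred release instead of syncing).
  HIP_CHECK(hipMemcpyAsync(device, pinned, n * sizeof(float), hipMemcpyHostToDevice, stream));
  HIP_CHECK(hipStreamSynchronize(stream));

  HIP_CHECK(hipFree(device));
  HIP_CHECK(hipHostFree(pinned));
  HIP_CHECK(hipStreamDestroy(stream));
  std::puts("copy done");
  return 0;
}
```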
-#if defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING) - -#include -#include - -#include "core/providers/rocm/rocm_profiler.h" - -namespace onnxruntime { -namespace profiling { - -RocmProfiler::RocmProfiler() { - auto& manager = RoctracerManager::GetInstance(); - client_handle_ = manager.RegisterClient(); -} - -RocmProfiler::~RocmProfiler() { - auto& manager = RoctracerManager::GetInstance(); - manager.DeregisterClient(client_handle_); -} - -} // namespace profiling -} // namespace onnxruntime -#endif diff --git a/onnxruntime/core/providers/rocm/rocm_profiler.h b/onnxruntime/core/providers/rocm/rocm_profiler.h deleted file mode 100644 index 52c6d4ea05f99..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_profiler.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#include -#include - -#include "core/common/gpu_profiler_common.h" -#include "roctracer_manager.h" - -#if defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING) - -namespace onnxruntime { -namespace profiling { - -using Events = std::vector; - -class RocmProfiler final : public GPUProfilerBase { - public: - RocmProfiler(); - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RocmProfiler); - ~RocmProfiler(); -}; - -} // namespace profiling -} // namespace onnxruntime - -#else - -namespace onnxruntime { -namespace profiling { - -class RocmProfiler final : public EpProfiler { - public: - RocmProfiler() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RocmProfiler); - ~RocmProfiler() {} - bool StartProfiling(TimePoint) override { return true; } - void EndProfiling(TimePoint, Events&) override {} - void Start(uint64_t) override {} - void Stop(uint64_t) override {} -}; - -} // namespace profiling -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc deleted file mode 100644 index 170a566d850b0..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
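The profiler header removed here follows a common compile-time null-object pattern: the same class name resolves to either a real tracer-backed profiler or an inert stub, so call sites never need #ifdefs. A generic sketch of that pattern (the names are illustrative stand-ins, not the ORT classes):

```cpp
// Compile-time null-object pattern: callers always use Profiler, and a build flag
// decides whether it does real work. Define WITH_PROFILING for the "real" variant.
#include <cstdint>
#include <iostream>

#ifdef WITH_PROFILING

class Profiler {
 public:
  void Start(uint64_t run_id) { std::cout << "start run " << run_id << "\n"; }
  void Stop(uint64_t run_id) { std::cout << "stop run " << run_id << "\n"; }
};

#else

class Profiler {
 public:
  // Inert stub: same interface, zero overhead, no #ifdefs at call sites.
  void Start(uint64_t) {}
  void Stop(uint64_t) {}
};

#endif

int main() {
  Profiler profiler;
  profiler.Start(1);  // prints only when built with -DWITH_PROFILING
  profiler.Stop(1);
  return 0;
}
```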
- -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_provider_factory.h" -#include "core/providers/rocm/rocm_provider_factory_creator.h" - -#include - -#include "core/providers/rocm/rocm_execution_provider.h" -#include "core/providers/rocm/rocm_execution_provider_info.h" -#include "core/providers/rocm/rocm_allocator.h" -#include "core/providers/rocm/gpu_data_transfer.h" -#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" - -#if defined(USE_ROCM) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) -#include "orttraining/training_ops/rocm/communication/nccl_service.h" -#endif - -using namespace onnxruntime; - -namespace onnxruntime { - -#if defined(USE_ROCM) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) -namespace rocm { -rocm::INcclService& GetINcclService(); -} -#endif - -void InitializeRegistry(); -void DeleteRegistry(); - -struct ROCMProviderFactory : IExecutionProviderFactory { - ROCMProviderFactory(const ROCMExecutionProviderInfo& info) - : info_{info} {} - ~ROCMProviderFactory() override {} - - std::unique_ptr CreateProvider() override; - - private: - ROCMExecutionProviderInfo info_; -}; - -std::unique_ptr ROCMProviderFactory::CreateProvider() { - return std::make_unique(info_); -} - -struct ProviderInfo_ROCM_Impl final : ProviderInfo_ROCM { - OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) override { - int num_devices; - auto hip_err = ::hipGetDeviceCount(&num_devices); - if (hip_err != hipSuccess) { - return CreateStatus(ORT_FAIL, "Failed to set device id since hipGetDeviceCount failed."); - } - - if (device_id >= num_devices) { - std::ostringstream ostr; - ostr << "Invalid device id. Device id should be less than total number of devices (" << num_devices << ")"; - return CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str()); - } - - hip_err = hipSetDevice(device_id); - if (hip_err != hipSuccess) { - return CreateStatus(ORT_FAIL, "Failed to set device id."); - } - return nullptr; - } - - OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) override { - auto hip_err = hipGetDevice(device_id); - if (hip_err != hipSuccess) { - return CreateStatus(ORT_FAIL, "Failed to get device id."); - } - return nullptr; - } - - std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) override { - return std::make_unique(device_id, name); - } - - std::unique_ptr CreateROCMPinnedAllocator(const char* name) override { - return std::make_unique(name); - } - - std::unique_ptr CreateGPUDataTransfer() override { - return std::make_unique(); - } - - void rocm__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) override { - return rocm::Impl_Cast(static_cast(stream), input_data, output_data, count); - } - - void rocm__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) override { - return rocm::Impl_Cast(static_cast(stream), input_data, output_data, count); - } - - void rocm__Impl_Cast(void* stream, const double* input_data, float* output_data, size_t count) override { - return rocm::Impl_Cast(static_cast(stream), input_data, output_data, count); - } - - void rocm__Impl_Cast(void* stream, const float* input_data, double* output_data, size_t count) override { - return rocm::Impl_Cast(static_cast(stream), input_data, output_data, count); - } - - Status RocmCall_false(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { 
return RocmCall(hipError_t(retCode), exprString, libName, hipError_t(successCode), msg, file, line); } - void RocmCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { RocmCall(hipError_t(retCode), exprString, libName, hipError_t(successCode), msg, file, line); } - - void CopyGpuToCpu(void* dst_ptr, const void* src_ptr, const size_t size, const OrtMemoryInfo& dst_location, const OrtMemoryInfo& src_location) override { - ORT_ENFORCE(dst_location.device.Type() == OrtDevice::CPU); - - // Current ROCM device. - int device; - HIP_CALL_THROW(hipGetDevice(&device)); - - if (device != src_location.id) { - // Need to switch to the allocating device. - HIP_CALL_THROW(hipSetDevice(src_location.id)); - // Copy from GPU to CPU. - HIP_CALL_THROW(hipMemcpy(dst_ptr, src_ptr, size, hipMemcpyDeviceToHost)); - // Switch back to current device. - HIP_CALL_THROW(hipSetDevice(device)); - } else { - // Copy from GPU to CPU. - HIP_CALL_THROW(hipMemcpy(dst_ptr, src_ptr, size, hipMemcpyDeviceToHost)); - } - } - - // Used by slice_concatenate_test.cc and onnxruntime_pybind_state.cc - - void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { - // hipMemcpy() operates on the default stream - HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); - - // To ensure that the copy has completed, invoke a stream sync for the default stream. - // For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. - // The function will return once the pageable buffer has been copied to the staging memory for DMA transfer - // to device memory, but the DMA to final destination may not have completed. - - HIP_CALL_THROW(hipStreamSynchronize(0)); - } - - // Used by onnxruntime_pybind_state.cc - void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { - // For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. 
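CopyGpuToCpu above has to hop to the device that owns the source allocation before calling hipMemcpy, then restore the previous device. A small RAII guard makes that pattern hard to get wrong; a sketch, again assuming a ROCm toolchain, with error handling reduced to an abort:

```cpp
// RAII device guard around a device-to-host copy, mirroring the manual
// hipGetDevice/hipSetDevice dance in CopyGpuToCpu above. Build with hipcc.
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

static void Check(hipError_t err) {
  if (err != hipSuccess) {
    std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err));
    std::abort();  // a real implementation would propagate a Status instead
  }
}

class ScopedHipDevice {
 public:
  explicit ScopedHipDevice(int target) {
    Check(hipGetDevice(&previous_));
    if (previous_ != target) {
      switched_ = true;
      Check(hipSetDevice(target));
    }
  }
  ~ScopedHipDevice() {
    if (switched_) hipSetDevice(previous_);  // best effort; don't throw from a destructor
  }

 private:
  int previous_ = 0;
  bool switched_ = false;
};

// Copy `bytes` from device memory that lives on `src_device` into host memory.
void CopyGpuToCpu(void* dst, const void* src, size_t bytes, int src_device) {
  ScopedHipDevice guard(src_device);  // the copy must run against the owning device
  Check(hipMemcpy(dst, src, bytes, hipMemcpyDeviceToHost));
}

int main() {
  float* device = nullptr;
  Check(hipMalloc(reinterpret_cast<void**>(&device), sizeof(float)));
  float host = 0.0f;
  CopyGpuToCpu(&host, device, sizeof(float), /*src_device=*/0);
  Check(hipFree(device));
  std::puts("copy done");
  return 0;
}
```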
- HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); - } - - int hipGetDeviceCount() override { - int num_devices = 0; - HIP_CALL_THROW(::hipGetDeviceCount(&num_devices)); - return num_devices; - } - - void ROCMExecutionProviderInfo__FromProviderOptions(const ProviderOptions& options, ROCMExecutionProviderInfo& info) override { - info = ROCMExecutionProviderInfo::FromProviderOptions(options); - } - -#if defined(USE_ROCM) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) - rocm::INcclService& GetINcclService() override { - return rocm::GetINcclService(); - } -#endif - - std::shared_ptr CreateExecutionProviderFactory(const ROCMExecutionProviderInfo& info) override { - return std::make_shared(info); - } - - std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return ROCMExecutionProvider::CreateRocmAllocator(device_id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); - } -} g_info; - -struct ROCM_Provider : Provider { - void* GetInfo() override { return &g_info; } - - std::shared_ptr CreateExecutionProviderFactory(const void* void_params) override { - auto params = reinterpret_cast(void_params); - - ROCMExecutionProviderInfo info{}; - info.device_id = gsl::narrow(params->device_id); - info.gpu_mem_limit = params->gpu_mem_limit; - info.arena_extend_strategy = static_cast(params->arena_extend_strategy); - info.miopen_conv_exhaustive_search = params->miopen_conv_exhaustive_search; - info.do_copy_in_default_stream = params->do_copy_in_default_stream != 0; - info.has_user_compute_stream = params->has_user_compute_stream != 0; - info.user_compute_stream = params->user_compute_stream; - info.default_memory_arena_cfg = params->default_memory_arena_cfg; - info.enable_hip_graph = params->enable_hip_graph != 0; - info.tunable_op.enable = params->tunable_op_enable; - info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; - info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; - - return std::make_shared(info); - } - - /** - * This function will be called by the C API UpdateROCMProviderOptions(). - * - * What this function does is equivalent to resetting the OrtROCMProviderOptions instance with - * default ROCMExecutionProviderInf instance first and then set up the provided provider options. - * See ROCMExecutionProviderInfo::FromProviderOptions() for more details. - */ - void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto internal_options = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); - auto& rocm_options = *reinterpret_cast(provider_options); - - rocm_options.device_id = internal_options.device_id; - rocm_options.gpu_mem_limit = internal_options.gpu_mem_limit; - rocm_options.arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); - rocm_options.miopen_conv_exhaustive_search = internal_options.miopen_conv_exhaustive_search; - rocm_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; - rocm_options.has_user_compute_stream = internal_options.has_user_compute_stream; - // The 'has_user_compute_stream' of the OrtROCMProviderOptions instance can be set by C API UpdateROCMProviderOptionsWithValue() as well. 
- // We only set the 'has_user_compute_stream' of the OrtROCMProviderOptions instance if it is provided in options - if (options.find("has_user_compute_stream") != options.end()) { - rocm_options.user_compute_stream = internal_options.user_compute_stream; - } - rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; - rocm_options.enable_hip_graph = internal_options.enable_hip_graph; - rocm_options.tunable_op_enable = internal_options.tunable_op.enable; - rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; - rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; - } - - ProviderOptions GetProviderOptions(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - return onnxruntime::ROCMExecutionProviderInfo::ToProviderOptions(options); - } - - void Initialize() override { - InitializeRegistry(); - } - - void Shutdown() override { - DeleteRegistry(); - } - -} g_provider; - -} // namespace onnxruntime - -extern "C" { - -ORT_API(onnxruntime::Provider*, GetProvider) { - return &onnxruntime::g_provider; -} -} diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.h b/onnxruntime/core/providers/rocm/rocm_provider_factory.h deleted file mode 100644 index 3238d66cee479..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "onnxruntime_c_api.h" -#include "core/framework/provider_options.h" -#include "core/common/common.h" - -namespace onnxruntime { -class IAllocator; -class IDataTransfer; -struct IExecutionProviderFactory; -struct ROCMExecutionProviderInfo; -enum class ArenaExtendStrategy : int32_t; -struct ROCMExecutionProviderExternalAllocatorInfo; - -namespace rocm { -class INcclService; -} - -struct ProviderInfo_ROCM { - virtual OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) = 0; - virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; - - virtual std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateROCMPinnedAllocator(const char* name) = 0; - virtual std::unique_ptr CreateGPUDataTransfer() = 0; - - virtual void rocm__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) = 0; - virtual void rocm__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) = 0; - virtual void rocm__Impl_Cast(void* stream, const double* input_data, float* output_data, size_t count) = 0; - virtual void rocm__Impl_Cast(void* stream, const float* input_data, double* output_data, size_t count) = 0; - - virtual Status RocmCall_false(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0; - virtual void RocmCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0; - - virtual void CopyGpuToCpu(void* dst_ptr, const void* src_ptr, const size_t size, const OrtMemoryInfo& dst_location, const OrtMemoryInfo& src_location) = 0; - virtual void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; - virtual void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual int hipGetDeviceCount() = 0; - virtual void ROCMExecutionProviderInfo__FromProviderOptions(const 
onnxruntime::ProviderOptions& options, onnxruntime::ROCMExecutionProviderInfo& info) = 0; - -#if defined(USE_ROCM) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) - virtual onnxruntime::rocm::INcclService& GetINcclService() = 0; -#endif - - virtual std::shared_ptr CreateExecutionProviderFactory(const onnxruntime::ROCMExecutionProviderInfo& info) = 0; - virtual std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; - - // This function is the entry point to ROCM EP's UT cases. - // All tests ared only called from onnxruntime_test_all. - virtual void TestAll() { - ORT_NOT_IMPLEMENTED(__FUNCTION__, " is only implements in test code path."); - } - - protected: - ~ProviderInfo_ROCM() = default; // Can only be destroyed through a subclass instance -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory_creator.h b/onnxruntime/core/providers/rocm/rocm_provider_factory_creator.h deleted file mode 100644 index 0972fc19cdbf7..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory_creator.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/providers.h" - -struct OrtROCMProviderOptions; - -namespace onnxruntime { -// defined in provider_bridge_ort.cc -struct RocmProviderFactoryCreator { - static std::shared_ptr Create(const OrtROCMProviderOptions* provider_options); -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc deleted file mode 100644 index bbd1e1befccee..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#include "core/providers/rocm/rocm_resource.h" -#include "core/providers/rocm/rocm_stream_handle.h" -#include "core/providers/rocm/rocm_common.h" -// #include "core/common/spin_pause.h" - -namespace onnxruntime { - -DeferredCpuAllocator::DeferredCpuAllocator(RocmStream& rocm_stream) : rocm_stream_(rocm_stream) { - OrtAllocator::version = ORT_API_VERSION; - OrtAllocator::Alloc = - [](OrtAllocator* this_, size_t size) { - auto self = reinterpret_cast(this_); - return self->rocm_stream_.GetCpuAllocator()->Alloc(size); - }; - OrtAllocator::Free = - [](OrtAllocator* this_, void* p) { - auto self = reinterpret_cast(this_); - self->rocm_stream_.EnqueDeferredCPUBuffer(p); - }; - OrtAllocator::Info = - [](const OrtAllocator* this_) { - auto self = reinterpret_cast(this_); - return &self->rocm_stream_.GetCpuAllocator()->Info(); - }; -} - -struct RocmNotification : public synchronize::Notification { - RocmNotification(Stream& s) : Notification(s) { - HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); - } - - ~RocmNotification() { - if (event_) - HIP_CALL_THROW(hipEventDestroy(event_)); - } - - void Activate() override { - // record event with hipEventBlockingSync so we can support sync on host without busy wait. 
- HIP_CALL_THROW(hipEventRecord(event_, static_cast(stream_.GetHandle()))); - } - - void wait_on_device(Stream& device_stream) { - ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString()); - // launch a wait command to the rocm stream - HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), - event_, 0)); - }; - - void wait_on_host() { - // CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); - HIP_CALL_THROW(hipEventSynchronize(event_)); - } - - hipEvent_t event_; -}; - -RocmStream::RocmStream(hipStream_t stream, - const OrtDevice& device, - AllocatorPtr cpu_allocator, - bool release_cpu_buffer_on_rocm_stream, - bool own_flag, - miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle, - const ROCMExecutionProviderInfo& ep_info) : Stream(stream, device), - own_stream_(own_flag), - cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream), - deferred_cpu_allocator_(*this), - ep_info_(ep_info) { - if (own_flag) { - HIPBLAS_CALL_THROW(hipblasCreate(&hipblas_handle_)); - HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); - MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); - MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); - } else { - hipblas_handle_ = external_hipblas_handle; - HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); - miopen_handle_ = external_miopen_handle; - MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); - } -} - -RocmStream::~RocmStream() { - ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); - if (own_stream_) { - hipblasDestroy(hipblas_handle_); - miopenDestroy(miopen_handle_); - auto* handle = GetHandle(); - if (handle) - HIP_CALL_THROW(hipStreamDestroy(static_cast(handle))); - } -} - -std::unique_ptr RocmStream::CreateNotification(size_t /*num_consumers*/) { - return std::make_unique(*this); -} - -void RocmStream::Flush() { - if (own_stream_) - HIP_CALL_THROW(hipStreamSynchronize(static_cast(GetHandle()))); -} - -void RocmStream::EnqueDeferredCPUBuffer(void* cpu_buffer) { - // stream is per thread, so don't need lock - deferred_cpu_buffers_.push_back(cpu_buffer); -} - -struct CpuBuffersInfo { - // This struct stores the information needed - // to release CPU buffers allocated for GPU kernels. - // It's used to enqueue their release after - // associated GPU kernels in a ROCM stream. - - // This is a CPU allocator in ROCM EP. - // It must be the one used to allocate the - // following pointers. - AllocatorPtr allocator; - // buffers[i] is the i-th pointer added by - // AddDeferredReleaseCPUPtr for a specific - // ROCM stream. For example, this fields - // should contain all values in - // deferred_release_buffer_pool_[my_stream] - // when release my_stream's buffers. - std::unique_ptr buffers; - // CPU buffer buffers[i]. - // Number of buffer points in "buffers". - size_t n_buffers; -}; - -static void ReleaseCpuBufferCallback(void* raw_info) { - std::unique_ptr info = std::make_unique(); - info.reset(reinterpret_cast(raw_info)); - for (size_t i = 0; i < info->n_buffers; ++i) { - info->allocator->Free(info->buffers[i]); - } -} - -Status RocmStream::CleanUpOnRunEnd() { - if (deferred_cpu_buffers_.empty()) - return Status::OK(); - // Release the ownership of cpu_buffers_info so that the underlying - // object will keep alive until the end of ReleaseCpuBufferCallback. 
- if (release_cpu_buffer_on_rocm_stream_ && cpu_allocator_->Info().alloc_type == OrtArenaAllocator) { - std::unique_ptr cpu_buffers_info = std::make_unique(); - cpu_buffers_info->allocator = cpu_allocator_; - cpu_buffers_info->buffers = std::make_unique(deferred_cpu_buffers_.size()); - for (size_t i = 0; i < deferred_cpu_buffers_.size(); ++i) { - cpu_buffers_info->buffers[i] = deferred_cpu_buffers_.at(i); - } - cpu_buffers_info->n_buffers = deferred_cpu_buffers_.size(); - HIP_RETURN_IF_ERROR(hipLaunchHostFunc(static_cast(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release())); - } else { - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(GetHandle()))); - for (auto* buffer : deferred_cpu_buffers_) { - cpu_allocator_->Free(buffer); - } - } - - deferred_cpu_buffers_.clear(); - return Status::OK(); -} - -void* RocmStream::GetResource(int version, int id) const { - ORT_ENFORCE(version <= ORT_ROCM_RESOURCE_VERSION, "resource version unsupported!"); - void* resource{}; - switch (id) { - case RocmResource::hip_stream_t: - return reinterpret_cast(GetHandle()); - break; - case RocmResource::miopen_handle_t: - return reinterpret_cast(miopen_handle_); - break; - case RocmResource::hipblas_handle_t: - return reinterpret_cast(hipblas_handle_); - break; - case RocmResource::deferred_cpu_allocator_t: - return const_cast(&deferred_cpu_allocator_); - break; - case RocmResource::device_id_t: - return reinterpret_cast(ep_info_.device_id); - break; - case RocmResource::arena_extend_strategy_t: - return reinterpret_cast(ep_info_.arena_extend_strategy); - break; - break; - default: - break; - } - return resource; -} - -// CPU Stream command handles -void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_device(stream); -} - -void WaitRocmNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_host(); -} - -void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, - const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, - bool release_cpu_buffer_on_rocm_stream, - hipStream_t external_stream, - bool use_existing_stream, - miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle, - const ROCMExecutionProviderInfo& ep_info) { - // wait rocm notification on rocm ep - stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitRocmNotificationOnDevice); - // wait rocm notification on cpu ep - stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitRocmNotificationOnHost); - if (!use_existing_stream) - stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream, ep_info](const OrtDevice& device) { - HIP_CALL_THROW(hipSetDevice(device.Id())); - hipStream_t stream = nullptr; - HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - // HIP_CALL_THROW(hipStreamCreate(&stream)); - return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr, ep_info); - }); - else - stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, - release_cpu_buffer_on_rocm_stream, - external_stream, - external_miopen_handle, - external_hipblas_handle, - ep_info](const OrtDevice& device) { - return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, false, external_miopen_handle, external_hipblas_handle, ep_info); - }); -} - -} // 
namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h deleted file mode 100644 index 320fb4661e987..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/providers/rocm/rocm_pch.h" -// #include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/shared_inc/rocm_call.h" -#include "core/framework/stream_handles.h" -#include "core/providers/rocm/rocm_execution_provider_info.h" - -namespace onnxruntime { - -struct RocmStream; -void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification); - -struct DeferredCpuAllocator : public OrtAllocator { - DeferredCpuAllocator(RocmStream&); - RocmStream& rocm_stream_; -}; - -struct RocmStream : Stream { - RocmStream(hipStream_t stream, - const OrtDevice& device, - AllocatorPtr cpu_allocator, - bool release_cpu_buffer_on_rocm_stream, - bool own_flag, - miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle, - const ROCMExecutionProviderInfo& ep_info); - - ~RocmStream(); - - std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; - - void Flush() override; - - Status CleanUpOnRunEnd() override; - - void EnqueDeferredCPUBuffer(void* cpu_buffer); - - bool own_stream_{true}; - - miopenHandle_t miopen_handle_{}; - - hipblasHandle_t hipblas_handle_{}; - - void* GetResource(int version, int id) const override; - - onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } - - WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; } - - private: - std::vector deferred_cpu_buffers_; - AllocatorPtr cpu_allocator_; - bool release_cpu_buffer_on_rocm_stream_{true}; - DeferredCpuAllocator deferred_cpu_allocator_; - const ROCMExecutionProviderInfo ep_info_; -}; - -void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, - const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, - bool release_cpu_buffer_on_rocm_stream, - hipStream_t external_stream, - bool use_existing_stream, - miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle, - const ROCMExecutionProviderInfo& ep_info); -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_utils.cu b/onnxruntime/core/providers/rocm/rocm_utils.cu deleted file mode 100644 index b817e025cedf4..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_utils.cu +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
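Editor's note: the RocmStream code removed above defers freeing CPU staging buffers until the GPU work already queued on the stream has finished, by enqueueing a host callback with hipLaunchHostFunc (and otherwise falling back to a blocking hipStreamSynchronize followed by an immediate free). A condensed sketch of that mechanism, with illustrative names and without ONNX Runtime's allocator types:

#include <hip/hip_runtime.h>
#include <cstdlib>
#include <utility>
#include <vector>

// Buffers handed to the callback; ownership transfers to the callback.
struct PendingHostBuffers {
  std::vector<void*> buffers;
};

static void FreeHostBuffers(void* user_data) {
  // Runs on a HIP-managed thread only after all prior work on the stream is done.
  auto* pending = static_cast<PendingHostBuffers*>(user_data);
  for (void* p : pending->buffers) std::free(p);
  delete pending;
}

inline void DeferredFreeOnStream(hipStream_t stream, std::vector<void*> buffers) {
  auto* pending = new PendingHostBuffers{std::move(buffers)};
  if (hipLaunchHostFunc(stream, FreeHostBuffers, pending) != hipSuccess) {
    // Fallback mirrors the non-arena path above: block until the stream drains, then free.
    (void)hipStreamSynchronize(stream);
    FreeHostBuffers(pending);
  }
}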
- -// Thrust code needs to be compiled with nvcc -#include -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/miopen_common.h" - -namespace onnxruntime { -namespace rocm { - -template -__global__ void _Fill( - T* output_data, - T val, - HIP_LONG N) { - HIP_LONG id = NumElementsPerThread * blockDim.x * blockIdx.x + threadIdx.x; - -#pragma unroll - for (int i = 0; i < NumElementsPerThread; i++) { - if (id < N) { - output_data[id] = val; - id += blockDim.x; - } - } -} - -template -void Fill(hipStream_t stream, T* output, T value, int64_t count) { - int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); - HIP_LONG N = static_cast(count); - _Fill - <<>>(output, value, N); -} -template -class ConstantBufferImpl : public IConstantBuffer { - public: - ConstantBufferImpl(T val) : buffer_(nullptr), count_(0), val_(val) { - } - ~ConstantBufferImpl() { - if (buffer_) - HIP_CALL_THROW(hipFree(buffer_)); - } - - virtual const T* GetBuffer(hipStream_t stream, size_t count) { - if (count > count_) { - if (buffer_) { - HIP_CALL_THROW(hipFree(buffer_)); - buffer_ = nullptr; - } - HIP_CALL_THROW(hipMalloc(&buffer_, count * sizeof(T))); - count_ = count; - - Fill(stream, buffer_, val_, count); - } - return buffer_; - } - - private: - T* buffer_; - size_t count_; - T val_; -}; - -template -std::unique_ptr> CreateConstantOnes() { - return std::make_unique>(Consts::One); -} - -template std::unique_ptr> CreateConstantOnes(); -template std::unique_ptr> CreateConstantOnes(); -template std::unique_ptr> CreateConstantOnes(); -template std::unique_ptr> CreateConstantOnes(); - -#define SPECIALIZED_FILL(T) \ - template void Fill(hipStream_t stream, T * output, T value, int64_t count); - -SPECIALIZED_FILL(int8_t) -SPECIALIZED_FILL(int16_t) -SPECIALIZED_FILL(int32_t) -SPECIALIZED_FILL(int64_t) -SPECIALIZED_FILL(float) -SPECIALIZED_FILL(double) -SPECIALIZED_FILL(__half) -SPECIALIZED_FILL(BFloat16) - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/roctracer_manager.cc b/onnxruntime/core/providers/rocm/roctracer_manager.cc deleted file mode 100644 index 51b96e56ba234..0000000000000 --- a/onnxruntime/core/providers/rocm/roctracer_manager.cc +++ /dev/null @@ -1,311 +0,0 @@ -#include -#include -#include - -#include "roctracer_manager.h" - -namespace onnxruntime { -namespace profiling { - -// allocate a 16K buffer for recording async activities -static constexpr size_t kActivityBufferSize = 0x4000; - -const std::vector RoctracerManager::hip_api_calls_to_trace = { - "hipMemcpy", - "hipMemcpy2D", - "hipMemcpyAsync", - "hipMemcpy2DAsync", - "hipMemcpyWithStream", - "hipLaunchKernel", - "hipMemset", - "hipMemsetAsync", - "hipExtModuleLaunchKernel", - "hipExtLaunchKernel", -}; - -// Implementation of RoctracerManager -RoctracerManager& RoctracerManager::GetInstance() { - static RoctracerManager instance; - return instance; -} - -RoctracerManager::~RoctracerManager() {} - -#define ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL(expr_) \ - do { \ - if (expr_ != ROCTRACER_STATUS_SUCCESS) { \ - OnStopLogging(); \ - return false; \ - } \ - } while (false) - -bool RoctracerManager::OnStartLogging() { - // The following line shows up in all the samples, I do not know - // what the point is, but without it, the roctracer APIs don't work. 
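Editor's note: many template arguments in the deleted rocm_utils.cu above (for example in _Fill, ConstantBufferImpl, and CreateConstantOnes) were lost when this diff was rendered. For reference, a minimal sketch of the fill-kernel idea as it appears above: an element type and a per-thread element count as template parameters, launched over ceil(count / (threads_per_block * elements_per_thread)) blocks. The constants below are assumptions standing in for GridDim::maxThreadsPerBlock and GridDim::maxElementsPerThread.

#include <hip/hip_runtime.h>

template <typename T, int ElementsPerThread>
__global__ void FillKernel(T* output, T value, int n) {
  int id = ElementsPerThread * blockDim.x * blockIdx.x + threadIdx.x;
#pragma unroll
  for (int i = 0; i < ElementsPerThread; ++i) {
    if (id < n) {
      output[id] = value;
      id += blockDim.x;  // threads in a block stride through their chunk
    }
  }
}

template <typename T>
void Fill(hipStream_t stream, T* output, T value, int n) {
  constexpr int kThreadsPerBlock = 256;   // assumed value
  constexpr int kElementsPerThread = 4;   // assumed value
  const int per_block = kThreadsPerBlock * kElementsPerThread;
  const int blocks = (n + per_block - 1) / per_block;
  hipLaunchKernelGGL((FillKernel<T, kElementsPerThread>), dim3(blocks), dim3(kThreadsPerBlock), 0, stream,
                     output, value, n);
}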
- roctracer_set_properties(ACTIVITY_DOMAIN_HIP_API, nullptr); - - roctracer_properties_t hcc_cb_properties; - memset(&hcc_cb_properties, 0, sizeof(roctracer_properties_t)); - hcc_cb_properties.buffer_size = kActivityBufferSize; - hcc_cb_properties.buffer_callback_fun = ActivityCallback; - ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL(roctracer_open_pool(&hcc_cb_properties)); - - // Enable selective activity and API callbacks for the HIP APIs - ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL(roctracer_disable_domain_callback(ACTIVITY_DOMAIN_HIP_API)); - ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL(roctracer_disable_domain_activity(ACTIVITY_DOMAIN_HIP_API)); - - for (auto const& logged_api : hip_api_calls_to_trace) { - uint32_t cid = 0; - ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL( - roctracer_op_code(ACTIVITY_DOMAIN_HIP_API, logged_api.c_str(), &cid, nullptr)); - ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL( - roctracer_enable_op_callback(ACTIVITY_DOMAIN_HIP_API, cid, ApiCallback, nullptr)); - ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL( - roctracer_enable_op_activity(ACTIVITY_DOMAIN_HIP_API, cid)); - } - - // Enable activity logging in the HIP_OPS/HCC_OPS domain. - ROCTRACER_STATUS_RETURN_FALSE_ON_FAIL(roctracer_enable_domain_activity(ACTIVITY_DOMAIN_HIP_OPS)); - - roctracer_start(); - return true; -} - -void RoctracerManager::OnStopLogging() { - roctracer_disable_domain_activity(ACTIVITY_DOMAIN_HIP_API); - roctracer_disable_domain_activity(ACTIVITY_DOMAIN_HIP_OPS); - roctracer_disable_domain_callback(ACTIVITY_DOMAIN_HIP_API); - roctracer_stop(); - roctracer_flush_activity(); - roctracer_close_pool(); - api_call_args_.clear(); -} - -void RoctracerManager::ActivityCallback(const char* begin, const char* end, void* arg) { - size_t size = end - begin; - ProfilerActivityBuffer activity_buffer{reinterpret_cast(begin), size}; - auto& instance = GetInstance(); - instance.EnqueueActivityBuffer(std::move(activity_buffer)); -} - -void RoctracerManager::ApiCallback(uint32_t domain, uint32_t cid, const void* callback_data, void* arg) { - if (domain != ACTIVITY_DOMAIN_HIP_API) { - return; - } - const hip_api_data_t* data = reinterpret_cast(callback_data); - if (data->phase == ACTIVITY_API_PHASE_EXIT) { - // We only save args for async launches on the ACTIVITY_API_PHASE_ENTER phase - return; - } - - auto& instance = GetInstance(); - { - std::lock_guard lock(instance.api_call_args_mutex_); - auto& record = instance.api_call_args_[data->correlation_id]; - record.domain_ = domain; - record.cid_ = cid; - record.api_data_ = *data; - } -} - -bool RoctracerManager::PushUniqueCorrelation(uint64_t unique_cid) { - return roctracer_activity_push_external_correlation_id(unique_cid) == ROCTRACER_STATUS_SUCCESS; -} - -void RoctracerManager::PopUniqueCorrelation(uint64_t& popped_unique_cid) { - if (roctracer_activity_pop_external_correlation_id(&popped_unique_cid) != ROCTRACER_STATUS_SUCCESS) { - popped_unique_cid = 0; - } -} - -void RoctracerManager::FlushActivities() { - roctracer_flush_activity(); -} - -uint64_t RoctracerManager::GetGPUTimestampInNanoseconds() { - uint64_t result; - if (roctracer_get_timestamp(&result) != ROCTRACER_STATUS_SUCCESS) { - ORT_THROW("Could not retrieve timestamp from GPU!"); - } - return result; -} - -static inline std::string MemcpyKindToString(hipMemcpyKind kind) { - switch (kind) { - case hipMemcpyHostToHost: - return "H2H"; - case hipMemcpyHostToDevice: - return "H2D"; - case hipMemcpyDeviceToHost: - return "D2H"; - case hipMemcpyDeviceToDevice: - return "D2D"; - default: - return "Default"; - } -} - -bool 
RoctracerManager::CreateEventForActivityRecord(const roctracer_record_t* record, - uint64_t start_time_ns, - const ApiCallRecord& call_record, - EventRecord& event) { - std::string name; - std::unordered_map args; - - switch (call_record.cid_) { - case HIP_API_ID_hipLaunchKernel: { - auto const& launch_args = call_record.api_data_.args.hipLaunchKernel; - name = demangle(hipKernelNameRefByPtr(launch_args.function_address, - launch_args.stream)); - - args = { - {"stream", PointerToHexString((void*)(launch_args.stream))}, - {"grid_x", std::to_string(launch_args.numBlocks.x)}, - {"grid_y", std::to_string(launch_args.numBlocks.y)}, - {"grid_z", std::to_string(launch_args.numBlocks.z)}, - {"block_x", std::to_string(launch_args.dimBlocks.x)}, - {"block_y", std::to_string(launch_args.dimBlocks.y)}, - {"block_z", std::to_string(launch_args.dimBlocks.z)}}; - break; - } - - case HIP_API_ID_hipMemset: - case HIP_API_ID_hipMemsetAsync: { - auto const& launch_args = call_record.api_data_.args; - name = roctracer_op_string(call_record.domain_, call_record.cid_, 0); - - args = { - {"stream", call_record.cid_ == HIP_API_ID_hipMemset - ? "0" - : PointerToHexString((void*)launch_args.hipMemsetAsync.stream)}, - {"dst", PointerToHexString(launch_args.hipMemset.dst)}, - {"size", std::to_string(launch_args.hipMemset.sizeBytes)}, - {"value", std::to_string(launch_args.hipMemset.value)}}; - break; - } - - case HIP_API_ID_hipMemcpy: - case HIP_API_ID_hipMemcpyAsync: - case HIP_API_ID_hipMemcpyWithStream: { - auto const& launch_args = call_record.api_data_.args; - name = roctracer_op_string(call_record.domain_, call_record.cid_, 0); - - args = { - {"stream", call_record.cid_ == HIP_API_ID_hipMemcpy - ? "0" - : PointerToHexString((void*)launch_args.hipMemcpyAsync.stream)}, - {"src", PointerToHexString(launch_args.hipMemcpy.src)}, - {"dst", PointerToHexString(launch_args.hipMemcpy.dst)}, - {"kind", MemcpyKindToString(launch_args.hipMemcpy.kind)}}; - break; - } - - case HIP_API_ID_hipMemcpy2D: - case HIP_API_ID_hipMemcpy2DAsync: { - auto const& launch_args = call_record.api_data_.args; - name = roctracer_op_string(call_record.domain_, call_record.cid_, 0); - - args = { - {"stream", call_record.cid_ == HIP_API_ID_hipMemcpy2D - ? 
"0" - : PointerToHexString((void*)launch_args.hipMemcpy2DAsync.stream)}, - {"src", PointerToHexString(launch_args.hipMemcpy2D.src)}, - {"dst", PointerToHexString(launch_args.hipMemcpy2D.dst)}, - {"spitch", std::to_string(launch_args.hipMemcpy2D.spitch)}, - {"dpitch", std::to_string(launch_args.hipMemcpy2D.dpitch)}, - {"width", std::to_string(launch_args.hipMemcpy2D.width)}, - {"height", std::to_string(launch_args.hipMemcpy2D.height)}, - {"kind", MemcpyKindToString(launch_args.hipMemcpy2D.kind)}}; - break; - } - - case HIP_API_ID_hipExtModuleLaunchKernel: { - auto const& launch_args = call_record.api_data_.args.hipExtModuleLaunchKernel; - name = demangle(hipKernelNameRef(launch_args.f)); - - args = { - {"stream", PointerToHexString((void*)launch_args.hStream)}, - {"grid_x", std::to_string(launch_args.globalWorkSizeX)}, - {"grid_y", std::to_string(launch_args.globalWorkSizeY)}, - {"grid_z", std::to_string(launch_args.globalWorkSizeZ)}, - {"block_x", std::to_string(launch_args.localWorkSizeX)}, - {"block_y", std::to_string(launch_args.localWorkSizeY)}, - {"block_z", std::to_string(launch_args.localWorkSizeZ)}, - }; - break; - } - - case HIP_API_ID_hipExtLaunchKernel: { - auto const& launch_args = call_record.api_data_.args.hipExtLaunchKernel; - name = demangle(hipKernelNameRefByPtr(launch_args.function_address, - launch_args.stream)); - - args = { - {"stream", PointerToHexString((void*)(launch_args.stream))}, - {"grid_x", std::to_string(launch_args.numBlocks.x)}, - {"grid_y", std::to_string(launch_args.numBlocks.y)}, - {"grid_z", std::to_string(launch_args.numBlocks.z)}, - {"block_x", std::to_string(launch_args.dimBlocks.x)}, - {"block_y", std::to_string(launch_args.dimBlocks.y)}, - {"block_z", std::to_string(launch_args.dimBlocks.z)}}; - break; - } - - default: - return false; - } - - new (&event) EventRecord{ - /* cat = */ EventCategory::KERNEL_EVENT, - /* pid = */ -1, - /* tid = */ -1, - /* name = */ std::move(name), - /* ts = */ (int64_t)(this->NormalizeGPUTimestampToCPUEpoch(record->begin_ns) - start_time_ns) / 1000, - /* dur = */ (int64_t)(record->end_ns - record->begin_ns) / 1000, - /* args = */ std::move(args)}; - return true; -} - -void RoctracerManager::ProcessActivityBuffers(const std::vector& buffers, - const TimePoint& start_time) { - auto start_time_ns = std::chrono::duration_cast(start_time.time_since_epoch()).count(); - - for (auto const& buffer : buffers) { - auto current_record = reinterpret_cast(buffer.GetData()); - auto data_end = reinterpret_cast(buffer.GetData() + buffer.GetSize()); - for (; current_record < data_end; roctracer_next_record(current_record, ¤t_record)) { - EventRecord event; - if (current_record->domain == ACTIVITY_DOMAIN_EXT_API) { - NotifyNewCorrelation(current_record->correlation_id, current_record->external_id); - continue; - } else if (current_record->domain == ACTIVITY_DOMAIN_HIP_OPS) { - if (current_record->op == 1 && current_record->kind == HipOpMarker) { - // this is just a marker, ignore it. 
- continue; - } - - auto api_it = api_call_args_.find(current_record->correlation_id); - if (api_it == api_call_args_.end()) { - // we're not tracking this activity, ignore it - continue; - } - - auto const& call_record = api_it->second; - if (!CreateEventForActivityRecord(current_record, start_time_ns, call_record, event)) { - // No event created, skip to the next record to avoid associating an empty - // event with a client - continue; - } - } else { - // ignore the superfluous event: this is probably a HIP API callback, which - // we've had to enable to receive external correlation ids - continue; - } - // map the event to the right client - MapEventToClient(current_record->correlation_id, std::move(event)); - } - } -} - -} /* end namespace profiling */ -} /* end namespace onnxruntime */ diff --git a/onnxruntime/core/providers/rocm/roctracer_manager.h b/onnxruntime/core/providers/rocm/roctracer_manager.h deleted file mode 100644 index 52a5dccae4840..0000000000000 --- a/onnxruntime/core/providers/rocm/roctracer_manager.h +++ /dev/null @@ -1,62 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "core/common/gpu_profiler_common.h" -#include "core/common/inlined_containers.h" - -namespace onnxruntime { -namespace profiling { - -struct ApiCallRecord { - uint32_t domain_; - uint32_t cid_; - hip_api_data_t api_data_{}; -}; - -class RoctracerManager : public GPUTracerManager { - friend class GPUTracerManager; - - public: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RoctracerManager); - ~RoctracerManager(); - static RoctracerManager& GetInstance(); - - protected: - bool PushUniqueCorrelation(uint64_t unique_cid); - void PopUniqueCorrelation(uint64_t& popped_unique_cid); - void OnStopLogging(); - bool OnStartLogging(); - void ProcessActivityBuffers(const std::vector& buffers, - const TimePoint& start_time); - void FlushActivities(); - uint64_t GetGPUTimestampInNanoseconds(); - - private: - RoctracerManager() = default; - static void ActivityCallback(const char* begin, const char* end, void* arg); - static void ApiCallback(uint32_t domain, uint32_t cid, const void* callback_data, void* arg); - bool CreateEventForActivityRecord(const roctracer_record_t* record, uint64_t start_time_ns, - const ApiCallRecord& call_record, EventRecord& event); - - // Some useful constants for processing activity buffers - static constexpr uint32_t HipOpMarker = 4606; - - std::mutex api_call_args_mutex_; - InlinedHashMap api_call_args_; - - // The api calls to track - static const std::vector hip_api_calls_to_trace; -}; /* class RoctracerManager */ - -} /* end namespace profiling */ -} /* end namespace onnxruntime*/ diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h deleted file mode 100644 index 9d32fcb65d0d5..0000000000000 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ /dev/null @@ -1,958 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
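Editor's note: the tracer bring-up performed by RoctracerManager::OnStartLogging above reduces to setting the HIP_API domain properties, opening an activity pool with a buffer callback, enabling activity collection for the HIP_OPS domain (plus per-op callbacks for the HIP API calls of interest), and calling roctracer_start. A stripped-down sketch of that sequence, with error handling collapsed; header locations may differ across ROCm releases.

#include <cstring>
#include <roctracer/roctracer.h>      // header path varies by ROCm version
#include <roctracer/roctracer_hip.h>

static void OnActivityBuffer(const char* begin, const char* end, void* /*arg*/) {
  // Iterate the async activity records delivered for the enabled domains.
  const roctracer_record_t* record = reinterpret_cast<const roctracer_record_t*>(begin);
  const roctracer_record_t* last = reinterpret_cast<const roctracer_record_t*>(end);
  while (record < last) {
    // ... inspect record->domain, record->correlation_id, record->begin_ns, record->end_ns ...
    roctracer_next_record(record, &record);
  }
}

bool StartHipTracing() {
  roctracer_set_properties(ACTIVITY_DOMAIN_HIP_API, nullptr);  // required before other roctracer calls

  roctracer_properties_t props;
  std::memset(&props, 0, sizeof(props));
  props.buffer_size = 0x4000;               // 16 KB, as in the deleted manager above
  props.buffer_callback_fun = OnActivityBuffer;
  if (roctracer_open_pool(&props) != ROCTRACER_STATUS_SUCCESS) return false;

  if (roctracer_enable_domain_activity(ACTIVITY_DOMAIN_HIP_OPS) != ROCTRACER_STATUS_SUCCESS) return false;
  roctracer_start();
  return true;
}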
- -#pragma once - -#include "core/providers/rocm/backward_guard.h" -#include "core/providers/rocm/rocm_common.h" - -#define ORT_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -#if ORT_ROCBLAS_VERSION_DECIMAL >= 242 -#define FLAG rocblas_gemm_flags_fp16_alt_impl -#else -#define FLAG 0 -#endif -// needed to work around calling rocblas API instead of hipblas API -static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) { - switch (op) { - case HIPBLAS_OP_N: - return rocblas_operation_none; - case HIPBLAS_OP_T: - return rocblas_operation_transpose; - case HIPBLAS_OP_C: - return rocblas_operation_conjugate_transpose; - } - assert(0 && "HIPBLAS_STATUS_INVALID_ENUM"); -} -static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) { - switch (error) { - case rocblas_status_size_unchanged: - case rocblas_status_size_increased: - case rocblas_status_success: - return HIPBLAS_STATUS_SUCCESS; - case rocblas_status_invalid_handle: - return HIPBLAS_STATUS_NOT_INITIALIZED; - case rocblas_status_not_implemented: - return HIPBLAS_STATUS_NOT_SUPPORTED; - case rocblas_status_invalid_pointer: - case rocblas_status_invalid_size: - case rocblas_status_invalid_value: - return HIPBLAS_STATUS_INVALID_VALUE; - case rocblas_status_memory_error: - return HIPBLAS_STATUS_ALLOC_FAILED; - case rocblas_status_internal_error: - return HIPBLAS_STATUS_INTERNAL_ERROR; - default: - assert(0 && "ROCBLAS_STATUS_INVALID_ENUM"); - return HIPBLAS_STATUS_INTERNAL_ERROR; - } -} - -using namespace onnxruntime; - -inline int get_flag() { - int result = BackwardPassGuard::is_backward_pass() ? FLAG : 0; - return result; -} - -// Generalize library calls to be use in template functions - -// hipblas - -// gemm -inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const float* A, int lda, - const float* B, int ldb, - const float* beta, - float* C, int ldc) { - return hipblasGemmEx(handle, - transa, - transb, - m, n, k, - alpha, - A, HIP_R_32F, lda, - B, HIP_R_32F, ldb, - beta, - C, HIP_R_32F, ldc, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); -} - -inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const double* alpha, - const double* A, int lda, - const double* B, int ldb, - const double* beta, - double* C, int ldc) { - return hipblasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); -} - -inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const half* alpha, - const half* A, int lda, - const half* B, int ldb, - const half* beta, - half* C, int ldc) { - float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); - return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - &h_a, - A, rocblas_datatype_f16_r, lda, - B, rocblas_datatype_f16_r, ldb, - &h_b, - C, rocblas_datatype_f16_r, ldc, - C, rocblas_datatype_f16_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); -} - -inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const 
half* A, int lda, - const half* B, int ldb, - const float* beta, - half* C, int ldc) { - return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - alpha, - A, rocblas_datatype_f16_r, lda, - B, rocblas_datatype_f16_r, ldb, - beta, - C, rocblas_datatype_f16_r, ldc, - C, rocblas_datatype_f16_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); -} - -inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const half* A, int lda, - const half* B, int ldb, - const float* beta, - half* C, int ldc, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return hipblasGemmHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, - B, ldb, - beta, - C, ldc); -} - -inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* A, int lda, - const BFloat16* B, int ldb, - const BFloat16* beta, - BFloat16* C, int ldc) { - float h_a = alpha->ToFloat(); - float h_b = beta->ToFloat(); - - // accumulating in FP32 - return hipblasGemmEx(handle, - transa, - transb, - m, n, k, - &h_a, - A, HIP_R_16BF, lda, - B, HIP_R_16BF, ldb, - &h_b, - C, HIP_R_16BF, ldc, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); -} - -// Compatible for function call with extra arguments (see cublasGemmHelper) -template -hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const Scalar* alpha, - const Scalar* A, int lda, - const Scalar* B, int ldb, - const Scalar* beta, - Scalar* C, int ldc, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return hipblasGemmHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, - B, ldb, - beta, - C, ldc); -} - -// batched gemm -inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const float* Aarray[], int lda, - const float* Barray[], int ldb, - const float* beta, - float* Carray[], int ldc, - int batchCount) { - return hipblasGemmBatchedEx(handle, - transa, - transb, - m, n, k, - alpha, - (const void**)Aarray, HIP_R_32F, lda, - (const void**)Barray, HIP_R_32F, ldb, - beta, - (void**)Carray, HIP_R_32F, ldc, - batchCount, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); -} -inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const double* alpha, - const double* Aarray[], int lda, - const double* Barray[], int ldb, - const double* beta, - double* Carray[], int ldc, - int batchCount) { - return hipblasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); -} -inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const half* alpha, - const half* Aarray[], int lda, - const half* Barray[], int ldb, - const half* beta, - half* Carray[], int ldc, - int batchCount) { - float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); - return 
rocBLASStatusToHIPStatus(rocblas_gemm_batched_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - &h_a, - (const void**)Aarray, rocblas_datatype_f16_r, lda, - (const void**)Barray, rocblas_datatype_f16_r, ldb, - &h_b, - (void**)Carray, rocblas_datatype_f16_r, ldc, - (void**)Carray, rocblas_datatype_f16_r, ldc, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); -} - -inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* Aarray[], int lda, - const BFloat16* Barray[], int ldb, - const BFloat16* beta, - BFloat16* Carray[], int ldc, - int batch_count) { - float h_a = alpha->ToFloat(); - float h_b = beta->ToFloat(); - - // accumulating in FP32 - return hipblasGemmBatchedEx(handle, - transa, - transb, - m, n, k, - &h_a, - (const void**)Aarray, HIP_R_16BF, lda, - (const void**)Barray, HIP_R_16BF, ldb, - &h_b, - (void**)Carray, HIP_R_16BF, ldc, - batch_count, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); -} - -// strided batched gemm -inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const float* A, int lda, - long long int strideA, - const float* B, int ldb, - long long int strideB, - const float* beta, - float* C, int ldc, - long long int strideC, - int batchCount) { - return hipblasGemmStridedBatchedEx(handle, - transa, - transb, - m, n, k, - alpha, - A, HIP_R_32F, lda, strideA, - B, HIP_R_32F, ldb, strideB, - beta, - C, HIP_R_32F, ldc, strideC, - batchCount, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); -} - -inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const double* alpha, - const double* A, int lda, - long long int strideA, - const double* B, int ldb, - long long int strideB, - const double* beta, - double* C, int ldc, - long long int strideC, - int batchCount) { - return hipblasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); -} - -inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const __half* alpha, - const __half* A, int lda, - long long int strideA, - const __half* B, int ldb, - long long int strideB, - const __half* beta, - __half* C, int ldc, - long long int strideC, - int batchCount) { - float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); - return rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - &h_a, - A, rocblas_datatype_f16_r, lda, strideA, - B, rocblas_datatype_f16_r, ldb, strideB, - &h_b, - C, rocblas_datatype_f16_r, ldc, strideC, - C, rocblas_datatype_f16_r, ldc, strideC, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); -} - -inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const __half* A, int lda, - intmax_t strideA, - const 
__half* B, int ldb, - intmax_t strideB, - const float* beta, - __half* C, int ldc, - intmax_t strideC, - int batchCount) { - return rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - alpha, - A, rocblas_datatype_f16_r, lda, strideA, - B, rocblas_datatype_f16_r, ldb, strideB, - beta, - C, rocblas_datatype_f16_r, ldc, strideC, - C, rocblas_datatype_f16_r, ldc, strideC, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); -} - -inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* A, int lda, - intmax_t strideA, - const BFloat16* B, int ldb, - intmax_t strideB, - const BFloat16* beta, - BFloat16* C, int ldc, - intmax_t strideC, - int batch_count) { - float h_a = alpha->ToFloat(); - float h_b = beta->ToFloat(); - // accumulating in FP32 - return hipblasGemmStridedBatchedEx(handle, - transa, - transb, - m, n, k, - &h_a, - A, HIP_R_16BF, lda, strideA, - B, HIP_R_16BF, ldb, strideB, - &h_b, - C, HIP_R_16BF, ldc, strideC, - batch_count, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); -} - -// Compatible for function call with with extra arguments (see cublasGemmStridedBatchedHelper) -template -hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const Scalar* alpha, - const Scalar* A, int lda, - intmax_t strideA, - const Scalar* B, int ldb, - intmax_t strideB, - const Scalar* beta, - Scalar* C, int ldc, - intmax_t strideC, - int batchCount, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return hipblasGemmStridedBatchedHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, strideA, - B, ldb, strideB, - beta, - C, ldc, strideC, - batchCount); -} - -inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const __half* A, int lda, - intmax_t strideA, - const __half* B, int ldb, - intmax_t strideB, - const float* beta, - __half* C, int ldc, - intmax_t strideC, - int batchCount, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return hipblasGemmStridedBatchedHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, strideA, - B, ldb, strideB, - beta, - C, ldc, strideC, - batchCount); -} - -// transpose using geam -inline hipblasStatus_t hipblasTransposeHelper(hipStream_t /*stream*/, hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { - return hipblasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); -} -inline hipblasStatus_t hipblasTransposeHelper(hipStream_t /*stream*/, hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, double* C, int ldc) { - return hipblasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); -} - -hipblasStatus_t hipblasTransposeHelper(hipStream_t stream, hipblasHandle_t, hipblasOperation_t, hipblasOperation_t, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); - -// copy 
-inline hipblasStatus_t hipblasCopyHelper(hipStream_t /*stream*/, hipblasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { - return hipblasScopy(handle, n, x, incx, y, incy); -} -inline hipblasStatus_t hipblasCopyHelper(hipStream_t /*stream*/, hipblasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { - return hipblasDcopy(handle, n, x, incx, y, incy); -} -hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t handle, int n, const half* x, int incx, half* y, int incy); -hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); - -// rocblas - -// gemm -inline rocblas_status rocblasGemmHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const float* alpha, - const float* A, int lda, - const float* B, int ldb, - const float* beta, - float* C, int ldc) { - return rocblas_gemm_ex(handle, - transa, - transb, - m, n, k, - alpha, - A, rocblas_datatype_f32_r, lda, - B, rocblas_datatype_f32_r, ldb, - beta, - C, rocblas_datatype_f32_r, ldc, - C, rocblas_datatype_f32_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); -} - -inline rocblas_status rocblasGemmHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const double* alpha, - const double* A, int lda, - const double* B, int ldb, - const double* beta, - double* C, int ldc) { - return rocblas_dgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); -} - -inline rocblas_status rocblasGemmHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const half* alpha, - const half* A, int lda, - const half* B, int ldb, - const half* beta, - half* C, int ldc) { - float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); - return rocblas_gemm_ex(handle, - transa, - transb, - m, n, k, - &h_a, - A, rocblas_datatype_f16_r, lda, - B, rocblas_datatype_f16_r, ldb, - &h_b, - C, rocblas_datatype_f16_r, ldc, - C, rocblas_datatype_f16_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag()); -} - -inline rocblas_status rocblasGemmHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const float* alpha, - const half* A, int lda, - const half* B, int ldb, - const float* beta, - half* C, int ldc) { - return rocblas_gemm_ex(handle, - transa, - transb, - m, n, k, - alpha, - A, rocblas_datatype_f16_r, lda, - B, rocblas_datatype_f16_r, ldb, - beta, - C, rocblas_datatype_f16_r, ldc, - C, rocblas_datatype_f16_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag()); -} - -inline rocblas_status rocblasGemmHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const float* alpha, - const half* A, int lda, - const half* B, int ldb, - const float* beta, - half* C, int ldc, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return rocblasGemmHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, - B, ldb, - beta, - C, ldc); -} - -inline rocblas_status rocblasGemmHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* A, int lda, - const BFloat16* B, int ldb, - const BFloat16* beta, - BFloat16* C, 
int ldc) { - float h_a = alpha->ToFloat(); - float h_b = beta->ToFloat(); - - // accumulating in FP32 - return rocblas_gemm_ex(handle, - transa, - transb, - m, n, k, - &h_a, - A, rocblas_datatype_bf16_r, lda, - B, rocblas_datatype_bf16_r, ldb, - &h_b, - C, rocblas_datatype_bf16_r, ldc, - C, rocblas_datatype_bf16_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); -} - -// Compatible for function call with extra arguments (see cublasGemmHelper) -template -rocblas_status rocblasGemmHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const Scalar* alpha, - const Scalar* A, int lda, - const Scalar* B, int ldb, - const Scalar* beta, - Scalar* C, int ldc, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return rocblasGemmHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, - B, ldb, - beta, - C, ldc); -} - -// batched gemm -inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const float* alpha, - const float* Aarray[], int lda, - const float* Barray[], int ldb, - const float* beta, - float* Carray[], int ldc, - int batchCount) { - return rocblas_gemm_batched_ex(handle, - transa, - transb, - m, n, k, - alpha, - (const void**)Aarray, rocblas_datatype_f32_r, lda, - (const void**)Barray, rocblas_datatype_f32_r, ldb, - beta, - (void**)Carray, rocblas_datatype_f32_r, ldc, - (void**)Carray, rocblas_datatype_f32_r, ldc, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); -} -inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const double* alpha, - const double* Aarray[], int lda, - const double* Barray[], int ldb, - const double* beta, - double* Carray[], int ldc, - int batchCount) { - return rocblas_dgemm_batched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); -} -inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const half* alpha, - const half* Aarray[], int lda, - const half* Barray[], int ldb, - const half* beta, - half* Carray[], int ldc, - int batchCount) { - float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); - return rocblas_gemm_batched_ex(handle, - transa, - transb, - m, n, k, - &h_a, - (const void**)Aarray, rocblas_datatype_f16_r, lda, - (const void**)Barray, rocblas_datatype_f16_r, ldb, - &h_b, - (void**)Carray, rocblas_datatype_f16_r, ldc, - (void**)Carray, rocblas_datatype_f16_r, ldc, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag()); -} - -inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* Aarray[], int lda, - const BFloat16* Barray[], int ldb, - const BFloat16* beta, - BFloat16* Carray[], int ldc, - int batch_count) { - float h_a = alpha->ToFloat(); - float h_b = beta->ToFloat(); - - // accumulating in FP32 - return rocblas_gemm_batched_ex(handle, - transa, - transb, - m, n, k, - &h_a, - (const void**)Aarray, rocblas_datatype_bf16_r, lda, - (const void**)Barray, rocblas_datatype_bf16_r, ldb, - &h_b, - (void**)Carray, rocblas_datatype_bf16_r, ldc, - (void**)Carray, 
rocblas_datatype_bf16_r, ldc, - batch_count, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); -} - -// strided batched gemm -inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const float* alpha, - const float* A, int lda, - long long int strideA, - const float* B, int ldb, - long long int strideB, - const float* beta, - float* C, int ldc, - long long int strideC, - int batchCount) { - return rocblas_gemm_strided_batched_ex(handle, - transa, - transb, - m, n, k, - alpha, - A, rocblas_datatype_f32_r, lda, strideA, - B, rocblas_datatype_f32_r, ldb, strideB, - beta, - C, rocblas_datatype_f32_r, ldc, strideC, - C, rocblas_datatype_f32_r, ldc, strideC, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); -} - -inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const double* alpha, - const double* A, int lda, - long long int strideA, - const double* B, int ldb, - long long int strideB, - const double* beta, - double* C, int ldc, - long long int strideC, - int batchCount) { - return rocblas_dgemm_strided_batched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); -} - -inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const __half* alpha, - const __half* A, int lda, - long long int strideA, - const __half* B, int ldb, - long long int strideB, - const __half* beta, - __half* C, int ldc, - long long int strideC, - int batchCount) { - float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); - float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); - return rocblas_gemm_strided_batched_ex(handle, - transa, - transb, - m, n, k, - &h_a, - A, rocblas_datatype_f16_r, lda, strideA, - B, rocblas_datatype_f16_r, ldb, strideB, - &h_b, - C, rocblas_datatype_f16_r, ldc, strideC, - C, rocblas_datatype_f16_r, ldc, strideC, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag()); -} - -inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const float* alpha, - const __half* A, int lda, - intmax_t strideA, - const __half* B, int ldb, - intmax_t strideB, - const float* beta, - __half* C, int ldc, - intmax_t strideC, - int batchCount) { - return rocblas_gemm_strided_batched_ex(handle, - transa, - transb, - m, n, k, - alpha, - A, rocblas_datatype_f16_r, lda, strideA, - B, rocblas_datatype_f16_r, ldb, strideB, - beta, - C, rocblas_datatype_f16_r, ldc, strideC, - C, rocblas_datatype_f16_r, ldc, strideC, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag()); -} - -inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* A, int lda, - intmax_t strideA, - const BFloat16* B, int ldb, - intmax_t strideB, - const BFloat16* beta, - BFloat16* C, int ldc, - intmax_t strideC, - int batch_count) { - float h_a = alpha->ToFloat(); - float h_b = beta->ToFloat(); - // accumulating in FP32 - return rocblas_gemm_strided_batched_ex(handle, - transa, - transb, - m, n, k, - &h_a, - A, 
rocblas_datatype_bf16_r, lda, strideA, - B, rocblas_datatype_bf16_r, ldb, strideB, - &h_b, - C, rocblas_datatype_bf16_r, ldc, strideC, - C, rocblas_datatype_bf16_r, ldc, strideC, - batch_count, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); -} - -// Compatible for function call with with extra arguments (see cublasGemmStridedBatchedHelper) -template -rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const Scalar* alpha, - const Scalar* A, int lda, - intmax_t strideA, - const Scalar* B, int ldb, - intmax_t strideB, - const Scalar* beta, - Scalar* C, int ldc, - intmax_t strideC, - int batchCount, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return rocblasGemmStridedBatchedHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, strideA, - B, ldb, strideB, - beta, - C, ldc, strideC, - batchCount); -} - -inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, int n, int k, - const float* alpha, - const __half* A, int lda, - intmax_t strideA, - const __half* B, int ldb, - intmax_t strideB, - const float* beta, - __half* C, int ldc, - intmax_t strideC, - int batchCount, - const hipDeviceProp_t&, - bool /*use_tf32*/) { - return rocblasGemmStridedBatchedHelper(handle, - transa, - transb, - m, n, k, - alpha, - A, lda, strideA, - B, ldb, strideB, - beta, - C, ldc, strideC, - batchCount); -} -bool CanUse_hipblasTransposeHelper_MLFloat16(int m, int n); -hipblasStatus_t hipblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h deleted file mode 100644 index 563ae17fcdb3b..0000000000000 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
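Editor's note: a recurring pattern in the deleted fpgeneric.h above is running FP16/BF16 GEMMs with FP32 accumulation: the half or BFloat16 alpha/beta are widened to float on the host and passed to the *_gemm_ex entry points together with an f32 compute type. A hedged sketch of the same call shape against the public hipBLAS API (enum spellings follow the usage above and can differ between hipBLAS releases; the helper name is illustrative):

#include <hipblas/hipblas.h>   // older ROCm releases install this as <hipblas.h>

// Sketch: C = alpha * A * B + beta * C with FP16 inputs/outputs and FP32
// accumulation, the same shape the deleted helpers forward to. A, B, C are
// device pointers in column-major layout; alpha/beta live on the host.
inline hipblasStatus_t GemmF16AccF32(hipblasHandle_t handle,
                                     int m, int n, int k,
                                     float alpha, const void* A, int lda,
                                     const void* B, int ldb,
                                     float beta, void* C, int ldc) {
  return hipblasGemmEx(handle, HIPBLAS_OP_N, HIPBLAS_OP_N,
                       m, n, k,
                       &alpha,
                       A, HIP_R_16F, lda,
                       B, HIP_R_16F, ldb,
                       &beta,
                       C, HIP_R_16F, ldc,
                       HIPBLAS_COMPUTE_32F,
                       HIPBLAS_GEMM_DEFAULT);
}

The FP16 overloads above additionally drop to rocblas_gemm_ex directly so they can pass rocblas_gemm_flags_fp16_alt_impl during backward passes (see get_flag() and BackwardPassGuard), which the plain hipBLAS entry point does not expose.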
- -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_pch.h" - -namespace onnxruntime { - -// ----------------------------------------------------------------------- -// Error handling -// ----------------------------------------------------------------------- - -template -std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); - -#define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) -#define ROCBLAS_CALL(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) -#define HIPBLAS_CALL(expr) (RocmCall((expr), #expr, "HIPBLAS", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define ROCMSMI_CALL(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) - -#define HIPSPARSE_CALL(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define HIPRAND_CALL(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define MIOPEN_CALL(expr) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) -#define MIOPEN_CALL2(expr, m) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) - -#define HIPFFT_CALL(expr) (RocmCall((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) - -#define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) -#define ROCBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) -#define HIPBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPBLAS", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define ROCMSMI_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) - -#define HIPSPARSE_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define HIPRAND_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) - -#define MIOPEN_CALL_THROW(expr) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) -#define MIOPEN_CALL_THROW2(expr, m) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) -#define HIPFFT_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) - -#ifdef ORT_USE_NCCL -#define NCCL_CALL(expr) (RocmCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) -#define NCCL_CALL_THROW(expr) (RocmCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) -#endif - -#ifdef USE_HIPBLASLT -#define HIPBLASLT_CALL(expr) (RocmCall((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define HIPBLASLT_CALL_THROW(expr) (RocmCall((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#endif - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/symbols.txt b/onnxruntime/core/providers/rocm/symbols.txt deleted file mode 100644 index 57a9f2bb602c2..0000000000000 --- a/onnxruntime/core/providers/rocm/symbols.txt +++ /dev/null @@ -1 +0,0 @@ -OrtSessionOptionsAppendExecutionProvider_ROCM diff --git a/onnxruntime/core/providers/rocm/tunable/gemm.cu b/onnxruntime/core/providers/rocm/tunable/gemm.cu deleted file mode 100644 index f40440e55be9b..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm.cu +++ 
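rocm_call.h, removed above, wraps every vendor status code in one templated checker plus thin macros. Below is a library-agnostic sketch of the same pattern, using a hypothetical FakeStatus enum so it stays self-contained and runnable; the real RocmCall is templated on ERRTYPE and a THROW flag and returns std::conditional_t<THROW, void, Status>.

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical status type standing in for hipError_t / rocblas_status.
enum class FakeStatus { kSuccess = 0, kOutOfMemory = 2 };

// One checker, parameterized on the status type and on whether to throw.
template <typename ERRTYPE, bool THROW>
bool CheckCall(ERRTYPE ret, const char* expr, const char* lib,
               ERRTYPE success, const char* file, int line) {
  if (ret == success) return true;
  std::ostringstream oss;
  oss << lib << " failure " << static_cast<int>(ret) << " in `" << expr
      << "` at " << file << ":" << line;
  if constexpr (THROW) throw std::runtime_error(oss.str());
  std::cerr << oss.str() << "\n";
  return false;
}

// Macro layer mirrors HIP_CALL / HIP_CALL_THROW from the deleted header.
#define FAKE_CALL(expr)       CheckCall<FakeStatus, false>((expr), #expr, "FAKE", FakeStatus::kSuccess, __FILE__, __LINE__)
#define FAKE_CALL_THROW(expr) CheckCall<FakeStatus, true>((expr), #expr, "FAKE", FakeStatus::kSuccess, __FILE__, __LINE__)

FakeStatus AllocateSomething() { return FakeStatus::kOutOfMemory; }

int main() {
  FAKE_CALL(AllocateSomething());         // logs and returns false
  try {
    FAKE_CALL_THROW(AllocateSomething()); // throws with expression/file/line baked in
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
}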
/dev/null @@ -1,280 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define _GEMM_H_KEEP_SIGNATURE_DEFINES -#include "core/providers/rocm/tunable/gemm.h" - -#include -#include - -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm_rocblas.h" -#include "core/providers/rocm/tunable/gemm_tunable.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -namespace row_major { - -namespace { -// a simple utility function that normalize alpha or beta to the desired datatype by an optional casting -template -inline DesiredT NormalizeScalar(ScalarT v) { - if constexpr (!std::is_same_v && std::is_same_v) { - return ToHipType::FromFloat(std::forward(v)); - } else { - return v; - } -} -} // namespace - -template -inline GEMM(T, ScalarT) { - GemmParams params; - params.tuning_ctx = tuning_ctx; - params.stream = stream; - params.handle = (rocblas_handle)handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - params.alpha = NormalizeScalar(alpha); - params.a = a; - params.lda = lda; - params.b = b; - params.ldb = ldb; - params.beta = NormalizeScalar(beta); - params.c = c; - params.ldc = ldc; - - if (tuning_ctx->IsTunableOpEnabled()) { - if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmTunableOp gemm{}; - return gemm(¶ms); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmTunableOp gemm{}; - return gemm(¶ms); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmTunableOp gemm{}; - return gemm(¶ms); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmTunableOp gemm{}; - return gemm(¶ms); - } - } - - return internal::RocBlasGemmOp(¶ms); -} - -template -inline BATCHED_GEMM(T, ScalarT) { - BatchedGemmParams params; - params.tuning_ctx = tuning_ctx; - params.stream = stream; - params.handle = (rocblas_handle)handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - params.alpha = NormalizeScalar(alpha); - params.as = as; - params.lda = lda; - params.bs = bs; - params.ldb = ldb; - params.beta = NormalizeScalar(beta); - params.cs = cs; - params.ldc = ldc; - params.batch = batch; - - if (tuning_ctx->IsTunableOpEnabled()) { - if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::BatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::BatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::BatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::BatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } - } - - return internal::RocBlasBatchedGemmOp(¶ms); -} - -template -inline STRIDED_BATCHED_GEMM(T, ScalarT) { - StridedBatchedGemmParams params; - params.tuning_ctx = tuning_ctx; - params.stream = stream; - params.handle = (rocblas_handle)handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - params.alpha = NormalizeScalar(alpha); - params.a = a; - params.lda = lda; - params.stride_a = stride_a; - params.b = b; - params.ldb = ldb; - params.stride_b = stride_b; - params.beta = NormalizeScalar(beta); - params.c = c; - params.ldc = ldc; - params.stride_c = stride_c; - params.batch = batch; - - if (tuning_ctx->IsTunableOpEnabled()) { - if (opa == BlasOp::N && opb 
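The Gemm dispatcher in the deleted gemm.cu selects one of four function-local static GemmTunableOp instances keyed on the (opa, opb) transpose pair, so each layout combination keeps its own tuning state, and falls back to the plain rocBLAS op when tuning is disabled. A minimal stand-alone sketch of that dispatch shape follows; TinyOp and its body are placeholders, not the real TunableOp.

#include <cstdio>

enum class BlasOp { N, T };

// Placeholder for GemmTunableOp<T, OpA, OpB>: each specialization would own
// its own list of candidate kernels and its own best-kernel cache.
template <BlasOp OpA, BlasOp OpB>
struct TinyOp {
  int Run(int m, int n, int k) {
    std::printf("running %c%c gemm %dx%dx%d\n",
                OpA == BlasOp::N ? 'N' : 'T',
                OpB == BlasOp::N ? 'N' : 'T', m, n, k);
    return 0;
  }
};

int Dispatch(BlasOp opa, BlasOp opb, int m, int n, int k) {
  // Function-local statics mirror the deleted code: tuning state persists across calls.
  if (opa == BlasOp::N && opb == BlasOp::N) { static TinyOp<BlasOp::N, BlasOp::N> op; return op.Run(m, n, k); }
  if (opa == BlasOp::T && opb == BlasOp::N) { static TinyOp<BlasOp::T, BlasOp::N> op; return op.Run(m, n, k); }
  if (opa == BlasOp::N && opb == BlasOp::T) { static TinyOp<BlasOp::N, BlasOp::T> op; return op.Run(m, n, k); }
  static TinyOp<BlasOp::T, BlasOp::T> op;  // remaining T/T case
  return op.Run(m, n, k);
}

int main() { Dispatch(BlasOp::N, BlasOp::T, 64, 128, 256); }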
== BlasOp::N) { - static internal::StridedBatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::StridedBatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::StridedBatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::StridedBatchedGemmTunableOp gemm{}; - return gemm(¶ms); - } - } - - return internal::RocBlasStridedBatchedGemmOp(¶ms); -} - -#define CALL_GEMM(T, ScalarT) \ - Gemm(tuning_ctx, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, a, lda, b, ldb, \ - beta, c, ldc) - -#define CALL_BATCHED_GEMM(T, ScalarT) \ - BatchedGemm( \ - tuning_ctx, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, as, lda, bs, ldb, \ - beta, cs, ldc, batch) - -#define CALL_STRIDED_BATCHED_GEMM(T, ScalarT) \ - StridedBatchedGemm( \ - tuning_ctx, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, \ - a, lda, stride_a, \ - b, ldb, stride_b, \ - beta, c, ldc, stride_c, \ - batch) - -// clang-format off -GEMM(double, double ) { return CALL_GEMM(double, double ); } -GEMM(float, float ) { return CALL_GEMM(float, float ); } -GEMM(half, half ) { return CALL_GEMM(half, half ); } -GEMM(BFloat16, BFloat16) { return CALL_GEMM(BFloat16, BFloat16); } -GEMM(double, float ) { return CALL_GEMM(double, float ); } -GEMM(half, float ) { return CALL_GEMM(half, float ); } -GEMM(BFloat16, float ) { return CALL_GEMM(BFloat16, float ); } - -BATCHED_GEMM(double, double ) { return CALL_BATCHED_GEMM(double, double ); } -BATCHED_GEMM(float, float ) { return CALL_BATCHED_GEMM(float, float ); } -BATCHED_GEMM(half, half ) { return CALL_BATCHED_GEMM(half, half ); } -BATCHED_GEMM(BFloat16, BFloat16) { return CALL_BATCHED_GEMM(BFloat16, BFloat16); } -BATCHED_GEMM(double, float ) { return CALL_BATCHED_GEMM(double, float ); } -BATCHED_GEMM(half, float ) { return CALL_BATCHED_GEMM(half, float ); } -BATCHED_GEMM(BFloat16, float ) { return CALL_BATCHED_GEMM(BFloat16, float ); } - -STRIDED_BATCHED_GEMM(double, double ) { return CALL_STRIDED_BATCHED_GEMM(double, double ); } -STRIDED_BATCHED_GEMM(float, float ) { return CALL_STRIDED_BATCHED_GEMM(float, float ); } -STRIDED_BATCHED_GEMM(half, half ) { return CALL_STRIDED_BATCHED_GEMM(half, half ); } -STRIDED_BATCHED_GEMM(BFloat16, BFloat16) { return CALL_STRIDED_BATCHED_GEMM(BFloat16, BFloat16); } -STRIDED_BATCHED_GEMM(double, float ) { return CALL_STRIDED_BATCHED_GEMM(double, float ); } -STRIDED_BATCHED_GEMM(half, float ) { return CALL_STRIDED_BATCHED_GEMM(half, float ); } -STRIDED_BATCHED_GEMM(BFloat16, float ) { return CALL_STRIDED_BATCHED_GEMM(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMM -#undef CALL_BATCHED_GEMM -#undef CALL_STRIDED_BATCHED_GEMM - -} // namespace row_major - -namespace column_major { - -#define CALL_GEMM_WITH_AB_SWAPPED(T, ScalarT) \ - row_major::Gemm(tuning_ctx, stream, handle, \ - opb, opa, \ - n, m, k, \ - alpha, b, ldb, a, lda, \ - beta, c, ldc) - -#define CALL_BATCHED_GEMM_WITH_AB_SWAPPED(T, ScalarT) \ - row_major::BatchedGemm( \ - tuning_ctx, stream, handle, \ - opb, opa, \ - n, m, k, \ - alpha, bs, ldb, as, lda, \ - beta, cs, ldc, batch) - -#define CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(T, ScalarT) \ - row_major::StridedBatchedGemm( \ - tuning_ctx, stream, handle, \ - opb, opa, \ - n, m, k, \ - alpha, \ - b, ldb, stride_b, \ - a, lda, stride_a, \ - beta, \ - c, ldc, stride_c, \ - batch) - -// clang-format off -GEMM(double, double ) { return 
CALL_GEMM_WITH_AB_SWAPPED(double, double ); } -GEMM(float, float ) { return CALL_GEMM_WITH_AB_SWAPPED(float, float ); } -GEMM(half, half ) { return CALL_GEMM_WITH_AB_SWAPPED(half, half ); } -GEMM(BFloat16, BFloat16) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, BFloat16); } -GEMM(double, float ) { return CALL_GEMM_WITH_AB_SWAPPED(double, float ); } -GEMM(half, float ) { return CALL_GEMM_WITH_AB_SWAPPED(half, float ); } -GEMM(BFloat16, float ) { return CALL_GEMM_WITH_AB_SWAPPED(BFloat16, float ); } - -BATCHED_GEMM(double, double ) { return CALL_BATCHED_GEMM_WITH_AB_SWAPPED(double, double ); } -BATCHED_GEMM(float, float ) { return CALL_BATCHED_GEMM_WITH_AB_SWAPPED(float, float ); } -BATCHED_GEMM(half, half ) { return CALL_BATCHED_GEMM_WITH_AB_SWAPPED(half, half ); } -BATCHED_GEMM(BFloat16, BFloat16) { return CALL_BATCHED_GEMM_WITH_AB_SWAPPED(BFloat16, BFloat16); } -BATCHED_GEMM(double, float ) { return CALL_BATCHED_GEMM_WITH_AB_SWAPPED(double, float ); } -BATCHED_GEMM(half, float ) { return CALL_BATCHED_GEMM_WITH_AB_SWAPPED(half, float ); } -BATCHED_GEMM(BFloat16, float ) { return CALL_BATCHED_GEMM_WITH_AB_SWAPPED(BFloat16, float ); } - -STRIDED_BATCHED_GEMM(double, double ) { return CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(double, double ); } -STRIDED_BATCHED_GEMM(float, float ) { return CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(float, float ); } -STRIDED_BATCHED_GEMM(half, half ) { return CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(half, half ); } -STRIDED_BATCHED_GEMM(BFloat16, BFloat16) { return CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(BFloat16, BFloat16); } -STRIDED_BATCHED_GEMM(double, float ) { return CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(double, float ); } -STRIDED_BATCHED_GEMM(half, float ) { return CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(half, float ); } -STRIDED_BATCHED_GEMM(BFloat16, float ) { return CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMM_WITH_AB_SWAPPED -#undef CALL_BATCHED_GEMM_WITH_AB_SWAPPED -#undef CALL_STRIDED_BATCHED_GEMM_WITH_AB_SWAPPED - -} // namespace column_major - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm.h b/onnxruntime/core/providers/rocm/tunable/gemm.h deleted file mode 100644 index 5b06535cb3862..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm.h +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
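The column_major wrappers above forward to row_major with A and B (and m, n) swapped. The identity being exploited: a row-major M x N buffer reinterpreted as column-major is the N x M transpose, and (A*B)^T = B^T * A^T. The following self-contained check verifies the trick with plain loops; the naive ColMajorGemm helper is illustrative only.

#include <cassert>
#include <vector>

// Naive column-major GEMM: C (m x n) = A (m x k) * B (k x n), no transposes.
void ColMajorGemm(int m, int n, int k,
                  const std::vector<double>& A, int lda,
                  const std::vector<double>& B, int ldb,
                  std::vector<double>& C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      double acc = 0.0;
      for (int p = 0; p < k; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

int main() {
  const int M = 2, K = 3, N = 4;
  // Row-major inputs.
  std::vector<double> A = {1, 2, 3,
                           4, 5, 6};        // M x K
  std::vector<double> B = {1, 0, 2, 1,
                           0, 1, 1, 2,
                           3, 1, 0, 1};     // K x N

  // Reference: direct row-major multiply.
  std::vector<double> ref(M * N, 0.0);
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int p = 0; p < K; ++p) ref[i * N + j] += A[i * K + p] * B[p * N + j];

  // Swap trick: treat the row-major buffers as column-major transposes and
  // compute C^T = B^T * A^T, i.e. call the column-major GEMM with B first.
  std::vector<double> C(M * N, 0.0);         // row-major M x N == column-major N x M
  ColMajorGemm(N, M, K, B, N, A, K, C, N);

  for (int i = 0; i < M * N; ++i) assert(C[i] == ref[i]);
  return 0;
}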
- -#pragma once - -#include "core/common/status.h" -#include "core/framework/float16.h" -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -#define GEMM(T, ScalarT) \ - common::Status Gemm( \ - RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ - ScalarT beta, T* c, std::int64_t ldc) - -#define BATCHED_GEMM(T, ScalarT) \ - common::Status BatchedGemm( \ - RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, \ - const T** as, std::int64_t lda, \ - const T** bs, std::int64_t ldb, \ - ScalarT beta, \ - T** cs, std::int64_t ldc, std::int64_t batch) - -#define STRIDED_BATCHED_GEMM(T, ScalarT) \ - common::Status StridedBatchedGemm( \ - RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, \ - const T* a, std::int64_t lda, std::int64_t stride_a, \ - const T* b, std::int64_t ldb, std::int64_t stride_b, \ - ScalarT beta, \ - T* c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch) - -namespace row_major { - -GEMM(double, double); -GEMM(float, float); -GEMM(half, half); -GEMM(BFloat16, BFloat16); -GEMM(double, float); -GEMM(half, float); -GEMM(BFloat16, float); - -BATCHED_GEMM(double, double); -BATCHED_GEMM(float, float); -BATCHED_GEMM(half, half); -BATCHED_GEMM(BFloat16, BFloat16); -BATCHED_GEMM(double, float); -BATCHED_GEMM(half, float); -BATCHED_GEMM(BFloat16, float); - -STRIDED_BATCHED_GEMM(double, double); -STRIDED_BATCHED_GEMM(float, float); -STRIDED_BATCHED_GEMM(half, half); -STRIDED_BATCHED_GEMM(BFloat16, BFloat16); -STRIDED_BATCHED_GEMM(double, float); -STRIDED_BATCHED_GEMM(half, float); -STRIDED_BATCHED_GEMM(BFloat16, float); - -} // namespace row_major - -// TODO(anyone): the caller should not need to swap the params a and b manually, but all the current callsites are -// doing so. It is cumbersome and unintuitive. At the moment, this namespace only ease the porting from old direct -// rocblas_gemm* calls to tunable gemm calls. After all porting of all callsites, if there is no column_major usecase -// left, then we shall remove this namespace, finally. 
-namespace column_major { - -GEMM(double, double); -GEMM(float, float); -GEMM(half, half); -GEMM(BFloat16, BFloat16); -GEMM(double, float); -GEMM(half, float); -GEMM(BFloat16, float); - -BATCHED_GEMM(double, double); -BATCHED_GEMM(float, float); -BATCHED_GEMM(half, half); -BATCHED_GEMM(BFloat16, BFloat16); -BATCHED_GEMM(double, float); -BATCHED_GEMM(half, float); -BATCHED_GEMM(BFloat16, float); - -STRIDED_BATCHED_GEMM(double, double); -STRIDED_BATCHED_GEMM(float, float); -STRIDED_BATCHED_GEMM(half, half); -STRIDED_BATCHED_GEMM(BFloat16, BFloat16); -STRIDED_BATCHED_GEMM(double, float); -STRIDED_BATCHED_GEMM(half, float); -STRIDED_BATCHED_GEMM(BFloat16, float); - -} // namespace column_major - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#ifndef _GEMM_H_KEEP_SIGNATURE_DEFINES -#undef GEMM -#endif diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh deleted file mode 100644 index b342bd6bc8a72..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -#ifdef USE_COMPOSABLE_KERNEL - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; - -template -auto GetCKGemmTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemm = ck::tensor_operation::device::DeviceGemm< - ALayout, BLayout, Row, - CKDataType, CKDataType, CKDataType, - Nop, Nop, Nop>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = impl->GetTypeString(); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); - - auto nop = Nop{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - nop, nop, nop); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " 
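gemm_ck.cuh above (like the rocBLAS and hipBLASLt headers removed further down) returns a vector of (type string, callable) pairs that the tunable op then registers and benchmarks. Here is a stripped-down sketch of that registry-and-benchmark loop in plain C++, with toy candidates standing in for real CK or rocBLAS kernels; names and timings are illustrative, and the real code times GPU launches, not host lambdas.

#include <chrono>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Params { int m, n, k; };
using Op = std::function<int(const Params&)>;  // 0 == OK, nonzero == unsupported argument

// Stand-in for GetCKGemmTypeStringAndOps / GetRocBlasGemmTypeStringAndOps.
std::vector<std::pair<std::string, Op>> GetToyOps() {
  std::vector<std::pair<std::string, Op>> ops;
  ops.emplace_back("toy_small_k_only", [](const Params& p) { return p.k < 64 ? 0 : 1; });
  ops.emplace_back("toy_generic",      [](const Params&)   { return 0; });
  return ops;
}

int main() {
  Params p{128, 128, 32};
  std::string best;
  double best_ms = 1e30;
  for (auto& [name, op] : GetToyOps()) {
    if (op(p) != 0) continue;  // skip candidates that reject these params
    auto t0 = std::chrono::steady_clock::now();
    op(p);                     // real code launches and times the GPU kernel here
    auto ms = std::chrono::duration<double, std::milli>(std::chrono::steady_clock::now() - t0).count();
    if (ms < best_ms) { best_ms = ms; best = name; }
  }
  std::cout << "fastest candidate: " << best << "\n";
}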
does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - return ret; -} - -template -auto GetCKStreamKGemmTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemm = ck::tensor_operation::device::DeviceGemmStreamK< - ALayout, BLayout, Row, - CKDataType, CKDataType, CKDataType, - Nop, Nop, Nop>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = impl->GetTypeString(); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); - - auto nop = Nop{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - nop, nop, nop); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - return ret; -} - -template -auto GetCKSplitKGemmTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< - ALayout, BLayout, Row, - CKDataType, CKDataType, CKDataType, - Nop, Nop, Nop>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto num_split : {4, 16, 64}) { - auto instances = InstanceFactory::GetInstances(); - for (auto&& impl : instances) { - auto type_string = impl->GetTypeString() + "_SplitK" + std::to_string(num_split); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->k < 128 * num_split, "k=", params->k, " is too small, it makes no sense to use this split-k gemm."); - - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); - - auto nop = Nop{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - nop, nop, nop, num_split); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - } - return 
ret; -} - -template -auto GetCKStridedBatchedGemmTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceStridedBatchedGemm = ck::tensor_operation::device::DeviceBatchedGemm< - ALayout, BLayout, Row, - CKDataType, CKDataType, CKDataType, - Nop, Nop, Nop>; - using InstanceFactory = - ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = impl->GetTypeString(); - - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const StridedBatchedGemmParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); - - auto nop = Nop{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - params->stride_a, params->stride_b, params->stride_c, - params->batch, - nop, nop, nop); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - return ret; -} -#else -struct Row {}; -struct Col {}; -#endif // USE_COMPOSABLE_KERNEL - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_common.h b/onnxruntime/core/providers/rocm/tunable/gemm_common.h deleted file mode 100644 index ca96e4a61003b..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm_common.h +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/framework/float8.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -enum class BlasOp { - N = 0, - T = 1, - NonTrans = 0, - Trans = 1, -}; - -inline std::string BlasOpToString(BlasOp op) { - switch (op) { - case BlasOp::N: - return "N"; - case BlasOp::T: - return "T"; - // following is unreachable, compiler is producing false-positive warning, unfortunately. - default: - ORT_THROW("unreachable"); - } -} - -// We don't assume the implementation is row-majored or column-majored. But for testing convenience, we assume all -// our wrappers have row-majored convention, since it is the native layout to numpy and pytorch. 
-template -struct GemmParams : tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - - rocblas_handle handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - T beta; - T* c; - int64_t ldc; -}; - -template -struct BatchedGemmParams : tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, "_B", batch); - } - - rocblas_handle handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T** as; - int64_t lda; - const T** bs; - int64_t ldb; - T beta; - T** cs; - int64_t ldc; - int64_t batch; -}; - -template -struct StridedBatchedGemmParams : tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, "_B", batch); - } - - rocblas_handle handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - int64_t stride_a; - const T* b; - int64_t ldb; - int64_t stride_b; - T beta; - T* c; - int64_t ldc; - int64_t stride_c; - int64_t batch; -}; - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h deleted file mode 100644 index 486ce5bfb731a..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#ifdef USE_HIPBLASLT -#include -#include -#include "core/providers/rocm/tunable/gemm_ck.cuh" -#include "core/providers/rocm/rocm_execution_provider.h" -#include "core/providers/rocm/rocm_stream_handle.h" -#endif - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/common/common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -using onnxruntime::contrib::rocm::blas::GemmFastGeluParams; - -#ifdef USE_HIPBLASLT - -// For large K and small M/N, K dim will be split to multiple workgroups and buffers, -// which will require additional workspace. Here we set the max workspace size to 32MB. 
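The Signature() overrides in the deleted gemm_common.h produce strings such as "NT_64_128_256" or "NN_64_128_256_B8", which key the tuning-results cache per problem shape. A tiny sketch of building such a key; the separator layout mirrors the MakeString calls above but is otherwise illustrative.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

enum class BlasOp { N, T };

std::string MakeGemmSignature(BlasOp opa, BlasOp opb,
                              int64_t m, int64_t n, int64_t k, int64_t batch = 1) {
  std::ostringstream oss;
  oss << (opa == BlasOp::N ? "N" : "T") << (opb == BlasOp::N ? "N" : "T")
      << "_" << m << "_" << n << "_" << k;
  if (batch > 1) oss << "_B" << batch;   // batched variants append the batch count
  return oss.str();
}

int main() {
  // Two different shapes produce two different cache entries.
  std::cout << MakeGemmSignature(BlasOp::N, BlasOp::T, 64, 128, 256) << "\n";    // NT_64_128_256
  std::cout << MakeGemmSignature(BlasOp::N, BlasOp::N, 64, 128, 256, 8) << "\n"; // NN_64_128_256_B8
}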
-constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024; - -enum ActivationType { - NONE = 0, - RELU = 1, - GELU = 2, -}; - -template -constexpr hipDataType HipBlasDataTypeFor(); - -template <> -constexpr hipDataType HipBlasDataTypeFor() { - return HIP_R_32F; -} - -template <> -constexpr hipDataType HipBlasDataTypeFor() { - return HIP_R_16F; -} - -template <> -constexpr hipDataType HipBlasDataTypeFor() { - return HIP_R_16BF; -} - -template <> -constexpr hipDataType HipBlasDataTypeFor() { - return HIP_R_64F; -} - -template -constexpr hipblasOperation_t MapBlasOpToHipBlasLt() { - if constexpr (Op == BlasOp::NonTrans) { - return HIPBLAS_OP_N; - } - return HIPBLAS_OP_T; -} - -template -int GetBatchCountFromParams(const ParamsT* params) { - ORT_UNUSED_PARAMETER(params); - return 1; -} - -template -int GetBatchCountFromParams(const StridedBatchedGemmParams* params) { - return params->batch; -} - -template -const T* GetBiasFromParams(const ParamsT* params) { - ORT_UNUSED_PARAMETER(params); - return nullptr; -} - -template -const T* GetBiasFromParams(const GemmFastGeluParams* params) { - return params->bias; -} - -template -std::string TypeStringFor() { - if constexpr (std::is_same_v>) { - return "Gemm"; - } else if constexpr (std::is_same_v>) { - return "StridedBatchedGemm"; - } else if constexpr (std::is_same_v>) { - return "GemmFastGelu"; - } - return "UnknownType"; -} - -template -auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationType::NONE) { - hipblasLtHandle_t handle; - HIPBLASLT_CALL_THROW(hipblasLtCreate(&handle)); - - hipblasOperation_t trans_a = MapBlasOpToHipBlasLt(); - hipblasOperation_t trans_b = MapBlasOpToHipBlasLt(); - hipDataType in_out_datatype = HipBlasDataTypeFor(); - std::vector heuristic_result; - - HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle, - hipblaslt_ext::GemmType::HIPBLASLT_GEMM, - trans_a, - trans_b, - in_out_datatype, - in_out_datatype, - in_out_datatype, - in_out_datatype, - HIPBLAS_COMPUTE_32F, - heuristic_result)); - HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle)); - - // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic. - std::sort(heuristic_result.begin(), - heuristic_result.end(), - [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) { - return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo); - }); - - int returned_algo_count = heuristic_result.size(); - std::vector>> ret; - for (int i = 0; i < returned_algo_count; i++) { - hipblasLtMatmulAlgo_t algo = heuristic_result[i].algo; - int algo_index = hipblaslt_ext::getIndexFromAlgo(algo); - auto hipblaslt_gemm_op = [=](const ParamsT* params) -> Status { - hipblasLtHandle_t op_handle; - HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&op_handle)); - - // Note: properties of original matrices A and B are swapped. - int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k; - int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m; - int64_t ldc = params->n; - int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n; - int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k; - int64_t stride_c = ldc * params->m; - float alpha = static_cast(params->alpha); - float beta = static_cast(params->beta); - int row_a, col_a, row_b, col_b, row_c, col_c; - row_a = lda; - col_a = (params->opb == BlasOp::N) ? params->k : params->n; - row_b = ldb; - col_b = (params->opa == BlasOp::N) ? 
params->m : params->k; - row_c = ldc; - col_c = params->m; - - hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; - hipblasLtMatmulDesc_t matmul; - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); - - int batch = GetBatchCountFromParams(params); - if (batch > 1) { - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); - } - - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t))); - - // Deduce enable_bias from params - auto d_bias = GetBiasFromParams(params); - bool enable_bias = d_bias != nullptr; - - hipblasLtEpilogue_t epilogue; - switch (activation_type) { - case ActivationType::NONE: - epilogue = enable_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT; - break; - case ActivationType::RELU: - epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU; - break; - case ActivationType::GELU: - epilogue = enable_bias ? 
HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU; - break; - default: - throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul"); - } - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); - - if (enable_bias) { - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*))); - } - - size_t workspace_size = 0; - hipblasLtMatmulAlgo_t algo_i = algo; - auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle, - matmul, - &alpha, - mat_a, - mat_b, - &beta, - mat_c, - mat_c, - algo_i, - workspace_size); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - status != HIPBLAS_STATUS_SUCCESS, - "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported"); - - IAllocatorUniquePtr workspace_buffer; - if (workspace_size > 0) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(workspace_size > kHipBlasLtMaxWorkSpaceSizeInBytes, - "Workspace size exceeds limit (32M): ", workspace_size); - workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes; - workspace_buffer = params->tuning_ctx->GetScratchBuffer(workspace_size, params->stream); - } - - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(op_handle, - matmul, - &alpha, - params->b, - mat_a, - params->a, - mat_b, - &beta, - params->c, - mat_c, - params->c, - mat_c, - &algo_i, - workspace_buffer.get(), - workspace_size, - params->StreamHandle())); - - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(op_handle)); - return Status::OK(); - }; - std::string type_string = onnxruntime::MakeString( - TypeStringFor(), "HipBlasLt_", i, "_algo_", algo_index); - ret.emplace_back(type_string, std::move(hipblaslt_gemm_op)); - } - return ret; -} - -template -auto GetHipBlasLtGemmTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(); -} - -template -auto GetHipBlasLtStridedBatchedGemmTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(); -} - -template -auto GetHipBlasLtGemmFastGeluTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(ActivationType::GELU); -} - -#endif // USE_HIPBLASLT - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h deleted file mode 100644 index a391d1af8868c..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h +++ /dev/null @@ -1,375 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
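The switch above folds the requested activation and the presence of a bias pointer into a single hipBLASLt epilogue. A compact restatement of that mapping, assuming the hipBLASLt header of a ROCm build for the HIPBLASLT_EPILOGUE_* enumerators that appear in the deleted code; the header path is an assumption.

#include <hipblaslt/hipblaslt.h>
#include <stdexcept>

enum class ActivationType { NONE, RELU, GELU };

// Mirrors the deleted selection logic: bias and activation fuse into one epilogue.
hipblasLtEpilogue_t SelectEpilogue(ActivationType act, bool has_bias) {
  switch (act) {
    case ActivationType::NONE: return has_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT;
    case ActivationType::RELU: return has_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU;
    case ActivationType::GELU: return has_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU;
  }
  throw std::runtime_error("unsupported activation type");
}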
- -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -// RAII style guard to set stream and restore original stream for rocblas_handle -class RocblasHandleStreamGuard { - public: - RocblasHandleStreamGuard(rocblas_handle handle, hipStream_t stream) : handle_{handle} { - ROCBLAS_CALL_THROW(rocblas_get_stream(handle_, &original_stream_)); - ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, stream)); - } - - ~RocblasHandleStreamGuard() { - ROCBLAS_CALL_THROW(rocblas_set_stream(handle_, original_stream_)); - } - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(RocblasHandleStreamGuard); - - private: - rocblas_handle handle_; - hipStream_t original_stream_; -}; - -#ifdef USE_ROCBLAS_EXTENSION_API - -template -constexpr rocblas_datatype RocBlasDataTypeFor(); - -template <> -constexpr rocblas_datatype RocBlasDataTypeFor() { - return rocblas_datatype_f32_r; -} - -template <> -constexpr rocblas_datatype RocBlasDataTypeFor() { - return rocblas_datatype_f16_r; -} - -template <> -constexpr rocblas_datatype RocBlasDataTypeFor() { - return rocblas_datatype_f64_r; -} - -template <> -constexpr rocblas_datatype RocBlasDataTypeFor() { - return rocblas_datatype_bf16_r; -} - -template -constexpr rocblas_datatype RocBlasComputeTypeFor(); - -template <> -constexpr rocblas_datatype RocBlasComputeTypeFor() { - return rocblas_datatype_f32_r; -} - -template <> -constexpr rocblas_datatype RocBlasComputeTypeFor() { - // Note that we're returning the _compute_ type for a given datatype. - // As of 12/2022, using compute type FP16 for 16-bit floats was much - // slower than using compute type FP32. So we use FP32 compute even for - // FP16 datatypes. This is how GEMM is implemented even in the function - // rocblasGemmHelper (see fpgeneric.h) - return rocblas_datatype_f32_r; -} - -template <> -constexpr rocblas_datatype RocBlasComputeTypeFor() { - return rocblas_datatype_f64_r; -} - -template <> -constexpr rocblas_datatype RocBlasComputeTypeFor() { - // Note that we're returning the _compute_ type for a given datatype. - // As of 12/2022, using compute type FP16 for 16-bit floats was much - // slower than using compute type FP32. So we use FP32 compute even for - // BF16 datatypes. This is how GEMM is implemented even in the function - // rocblasGemmHelper (see fpgeneric.h) - return rocblas_datatype_f32_r; -} - -template -auto DoCastForHalfOrBfloat16(const T fp) { - return fp; -} - -template <> -inline auto DoCastForHalfOrBfloat16(const half fp) { - // alpha and beta should be the same as compute_type, in half case it is float. - float h = onnxruntime::math::halfToFloat(*reinterpret_cast(&fp)); - return h; -} - -template <> -inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) { - // alpha and beta should be the same as compute_type, in bfloat16 case it is float. 
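RocblasHandleStreamGuard in the deleted gemm_rocblas.h is a save/set/restore RAII guard around the stream bound to a handle. Below is a generic, self-contained version of the same idea; the real guard calls rocblas_get_stream / rocblas_set_stream and checks both through ROCBLAS_CALL_THROW.

#include <cassert>
#include <functional>

// Generic scope guard: remember the current value, install a new one,
// restore the original on scope exit (even if an exception is thrown).
template <typename T>
class ScopedSetter {
 public:
  ScopedSetter(std::function<T()> get, std::function<void(T)> set, T new_value)
      : set_(std::move(set)), old_(get()) {
    set_(new_value);
  }
  ~ScopedSetter() { set_(old_); }
  ScopedSetter(const ScopedSetter&) = delete;
  ScopedSetter& operator=(const ScopedSetter&) = delete;

 private:
  std::function<void(T)> set_;
  T old_;
};

int main() {
  int bound_stream = 0;  // stands in for the stream currently bound to a rocblas_handle
  {
    ScopedSetter<int> guard([&] { return bound_stream; },
                            [&](int s) { bound_stream = s; }, 42);
    assert(bound_stream == 42);  // calls inside the scope see the temporary stream
  }
  assert(bound_stream == 0);     // original binding restored on scope exit
}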
- float h = fp.ToFloat(); - return h; -} - -template -auto GetRocBlasGemmTypeStringAndOps() { - rocblas_handle handle; - ROCBLAS_CALL_THROW(rocblas_create_handle(&handle)); - - int solution_size; - auto input_output_type = RocBlasDataTypeFor(); - auto compute_type = RocBlasComputeTypeFor(); - - // Get the number of available solutions - ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, - input_output_type, - input_output_type, - compute_type, - rocblas_gemm_flags_none, - nullptr, - &solution_size)); - - std::vector solutions(solution_size); - - // Get the list of available solutions - ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, - input_output_type, - input_output_type, - compute_type, - rocblas_gemm_flags_none, - solutions.data(), - &solution_size)); - - ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle)); - - // Sort the solutions in ascending order to make the solution vector deterministic across runs - std::sort(solutions.begin(), solutions.end()); - - std::vector>>> ret; - for (size_t i = 0; i < solutions.size(); ++i) { - auto solution = solutions[i]; - auto rocblas_gemm_op = [=](const GemmParams* params) -> Status { - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - auto status = rocblas_gemm_ex( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, input_output_type, params->ldb, - params->a, input_output_type, params->lda, - &h_b, - params->c, input_output_type, params->ldc, - params->c, input_output_type, params->ldc, - compute_type, - rocblas_gemm_algo_solution_index, - solution, - rocblas_gemm_flags_none); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); - - return Status::OK(); - }; - ret.emplace_back(std::make_pair( - onnxruntime::MakeString("RocBlasGemm_", i, "_sol_", solution), std::move(rocblas_gemm_op))); - } - return ret; -} - -template -auto GetRocBlasBatchedGemmTypeStringAndOps() { - rocblas_handle handle; - ROCBLAS_CALL_THROW(rocblas_create_handle(&handle)); - - int solution_size; - auto input_output_type = RocBlasDataTypeFor(); - auto compute_type = RocBlasComputeTypeFor(); - - // Get the number of available solutions - ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle, - input_output_type, - input_output_type, - compute_type, - rocblas_gemm_flags_none, - nullptr, - &solution_size)); - - std::vector solutions(solution_size); - - // Get the list of available solutions - ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle, - input_output_type, - input_output_type, - compute_type, - rocblas_gemm_flags_none, - solutions.data(), - &solution_size)); - - ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle)); - - // Sort the solutions in ascending order to make the solution vector deterministic across runs - std::sort(solutions.begin(), solutions.end()); - - std::vector>>> ret; - for (size_t i = 0; i < solutions.size(); ++i) { - auto solution = solutions[i]; - auto rocblas_gemm_op = [=](const BatchedGemmParams* params) -> Status { - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - auto status = rocblas_gemm_batched_ex( - params->handle, - params->opb == BlasOp::N ? 
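rocblas_gemm_ex_get_solutions_by_type below is driven with the common C two-call idiom: pass a null output pointer to learn how many solutions exist, size a buffer, call again to fill it, then sort for a deterministic candidate order. A self-contained sketch of that idiom with a stand-in enumerator; the function name and ids here are made up.

#include <algorithm>
#include <iostream>
#include <vector>

// Stand-in for the rocBLAS enumeration API: writes up to *count ids when a
// buffer is given, and reports the available count when it is not.
void GetSolutions(int* out, int* count) {
  static const int kAvailable[] = {71, 3, 42};
  if (out == nullptr) { *count = 3; return; }
  int n = std::min(*count, 3);
  for (int i = 0; i < n; ++i) out[i] = kAvailable[i];
  *count = n;
}

int main() {
  int count = 0;
  GetSolutions(nullptr, &count);           // 1st call: query the size only
  std::vector<int> solutions(count);
  GetSolutions(solutions.data(), &count);  // 2nd call: fill the buffer
  std::sort(solutions.begin(), solutions.end());  // deterministic ordering across runs
  for (int s : solutions) std::cout << s << " ";
  std::cout << "\n";                        // 3 42 71
}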
rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->bs, input_output_type, params->ldb, - params->as, input_output_type, params->lda, - &h_b, - params->cs, input_output_type, params->ldc, - params->cs, input_output_type, params->ldc, - params->batch, - compute_type, - rocblas_gemm_algo_solution_index, - solution, - rocblas_gemm_flags_none); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); - - return Status::OK(); - }; - ret.emplace_back(std::make_pair( - onnxruntime::MakeString("RocBlasBatchedGemm_", i, "_sol_", solution), std::move(rocblas_gemm_op))); - } - return ret; -} - -template -auto GetRocBlasStridedBatchedGemmTypeStringAndOps() { - rocblas_handle handle; - ROCBLAS_CALL_THROW(rocblas_create_handle(&handle)); - - int solution_size; - auto input_output_type = RocBlasDataTypeFor(); - auto compute_type = RocBlasComputeTypeFor(); - - // Get the number of available solutions - ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, - input_output_type, - input_output_type, - compute_type, - rocblas_gemm_flags_none, - nullptr, - &solution_size)); - - std::vector solutions(solution_size); - - // Get the list of available solutions - ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, - input_output_type, - input_output_type, - compute_type, - rocblas_gemm_flags_none, - solutions.data(), - &solution_size)); - - ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle)); - - // Sort the solutions in ascending order to make the solution vector deterministic across runs - std::sort(solutions.begin(), solutions.end()); - - std::vector>>> ret; - for (size_t i = 0; i < solutions.size(); ++i) { - auto solution = solutions[i]; - auto rocblas_gemm_op = [=](const StridedBatchedGemmParams* params) -> Status { - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - auto status = rocblas_gemm_strided_batched_ex( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, input_output_type, params->ldb, params->stride_b, - params->a, input_output_type, params->lda, params->stride_a, - &h_b, - params->c, input_output_type, params->ldc, params->stride_c, - params->c, input_output_type, params->ldc, params->stride_c, - params->batch, - compute_type, - rocblas_gemm_algo_solution_index, - solution, - rocblas_gemm_flags_none); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); - - return Status::OK(); - }; - ret.emplace_back(std::make_pair( - onnxruntime::MakeString("RocBlasStridedBatchedGemm_", i, "_sol_", solution), std::move(rocblas_gemm_op))); - } - return ret; -} - -#endif // USE_ROCBLAS_EXTENSION_API - -template -Status RocBlasGemmOp(const GemmParams* params) { - RocblasHandleStreamGuard guard(params->handle, params->StreamHandle()); - // NOTE: rocblas assumes the storage is column-majored, swapping A and B makes it have the same interface - // as those with row-majored convention. 
That is, if you treat the storage as row-majored but view the matrices as - // transposed, then by using the property Transpose(A*B) = Tranpose(B)*Transpose(A), the correctness is obvious. - return ROCBLAS_CALL(rocblasGemmHelper( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &(params->alpha), - params->b, params->ldb, - params->a, params->lda, - &(params->beta), - params->c, params->ldc)); -} - -template -Status RocBlasBatchedGemmOp(const BatchedGemmParams* params) { - RocblasHandleStreamGuard guard(params->handle, params->StreamHandle()); - return ROCBLAS_CALL(rocblasGemmBatchedHelper( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &(params->alpha), - params->bs, params->ldb, - params->as, params->lda, - &(params->beta), - params->cs, params->ldc, - params->batch)); -} - -template -Status RocBlasStridedBatchedGemmOp(const StridedBatchedGemmParams* params) { - RocblasHandleStreamGuard guard(params->handle, params->StreamHandle()); - return ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &(params->alpha), - params->b, params->ldb, params->stride_b, - params->a, params->lda, params->stride_a, - &(params->beta), - params->c, params->ldc, params->stride_c, - params->batch)); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh deleted file mode 100644 index 9228287fbbb89..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/tunable/gemm_ck.cuh" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/gemm_hipblaslt.h" -#include "core/providers/rocm/tunable/gemm_rocblas.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -bool IsZero(T v) { - return v == 0.0f; -} - -template <> -bool IsZero(BFloat16 v) { - return v.val == 0; -} - -template <> -bool IsZero(half v) { - return __half2float(v) == 0.0f; -} - -template -class GemmTunableOp : public TunableOp> { - public: - GemmTunableOp() { - this->RegisterOp(RocBlasGemmOp); - -#ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_ROCBLAS_EXTENSION_API - for (auto&& [_, op] : GetRocBlasGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKStreamKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetCKSplitKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } - - const GemmParams* PreTuning(const GemmParams* params) override { - if (!IsZero(params->beta)) { - // When beta != 0, C buffer is used as an input as well as an output. We need to create a proxy params for the - // tuning process. Otherwise, tuning will cause the C buffer been updated accumulatedly, say, we tune it for n - // iterations, then during tuning C^(1) = alpha A B + beta C^(0), ..., C^(n) = alpha A B + beta C^(n-1). And for - // the actual run after tuning, the result will be C^(n+1), whereas what we want is C^(1). This only happens if - // the tuning's FindFastest is invoked. - // - // Note, C^(i) is the C at i-th iteration. 
- GemmParams* proxy = new GemmParams(); - *proxy = *params; - HIP_CALL_THROW(hipMalloc(&(proxy->c), proxy->m * proxy->ldc * sizeof(T))); - return proxy; - } - - return params; - } - - void PostTuning(const GemmParams* params) override { - if (!IsZero(params->beta)) { - HIP_CALL_THROW(hipFree(params->c)); - delete params; - } - } -}; - -template -class BatchedGemmTunableOp : public TunableOp> { - public: - BatchedGemmTunableOp() { - this->RegisterOp(RocBlasBatchedGemmOp); - -#ifdef USE_ROCBLAS_EXTENSION_API - for (auto&& [_, op] : GetRocBlasBatchedGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } - - const BatchedGemmParams* PreTuning(const BatchedGemmParams* params) override { - if (!IsZero(params->beta)) { - // See GemmTunableOp::PreTuning for more details - BatchedGemmParams* proxy = new BatchedGemmParams(); - *proxy = *params; - - // malloc a large buffer and then slice it - const int single_buffer_bytes = CeilDiv(proxy->m * proxy->ldc * sizeof(T), 128) * 128; - T* buffer; - HIP_CALL_THROW(hipMalloc(&buffer, proxy->batch * single_buffer_bytes)); - std::vector buffer_ptrs(proxy->batch, nullptr); - for (int i = 0; i < proxy->batch; i++) { - // note the following is offseted by bytes - buffer_ptrs[i] = reinterpret_cast(reinterpret_cast(buffer) + i * single_buffer_bytes); - } - - // copy all ptrs to device - HIP_CALL_THROW(hipMalloc(&(proxy->cs), proxy->batch * sizeof(T*))); - HIP_CALL_THROW(hipMemcpy(proxy->cs, buffer_ptrs.data(), buffer_ptrs.size() * sizeof(T*), hipMemcpyHostToDevice)); - return proxy; - } - - return params; - } - - void PostTuning(const BatchedGemmParams* params) override { - if (!IsZero(params->beta)) { - T* buffer; - HIP_CALL_THROW(hipMemcpy(&buffer, params->cs, sizeof(T*), hipMemcpyDeviceToHost)); - HIP_CALL_THROW(hipFree(buffer)); - HIP_CALL_THROW(hipFree(params->cs)); - delete params; - } - } -}; - -template -class StridedBatchedGemmTunableOp : public TunableOp> { - public: - StridedBatchedGemmTunableOp() { - this->RegisterOp(RocBlasStridedBatchedGemmOp); - -#ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_ROCBLAS_EXTENSION_API - for (auto&& [_, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKStridedBatchedGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } - - const StridedBatchedGemmParams* PreTuning(const StridedBatchedGemmParams* params) override { - if (!IsZero(params->beta)) { - // See GemmTunableOp::PreTuning for more details - StridedBatchedGemmParams* proxy = new StridedBatchedGemmParams(); - *proxy = *params; - HIP_CALL_THROW(hipMalloc(&(proxy->c), proxy->batch * proxy->stride_c * sizeof(T))); - return proxy; - } - - return params; - } - - void PostTuning(const StridedBatchedGemmParams* params) override { - if (!IsZero(params->beta)) { - HIP_CALL_THROW(hipFree(params->c)); - delete params; - } - } -}; - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h b/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h deleted file mode 100644 index 95fa4f37d7f68..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h +++ /dev/null @@ 
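The PreTuning/PostTuning pairs above exist because with beta != 0 the output C is also an input, so benchmarking the same GEMM repeatedly on the caller's buffer would keep folding the previous result back in. The scalar-sized demonstration below shows that drift, which is exactly what the temporary proxy C buffer avoids; the numbers are arbitrary.

#include <iostream>

int main() {
  const double alpha = 1.0, beta = 0.5;
  const double ab = 2.0;      // stands in for one element of alpha * A * B
  double c = 1.0;             // initial C^(0)

  // What the caller wants: a single application, C^(1) = alpha*AB + beta*C^(0).
  const double expected = alpha * ab + beta * c;   // 2.5

  // What naive tuning would do: run the candidate repeatedly on the same buffer.
  for (int iter = 0; iter < 5; ++iter)
    c = alpha * ab + beta * c;   // C^(i) = alpha*AB + beta*C^(i-1)

  std::cout << "expected " << expected << ", after 5 tuning runs " << c << "\n";
  // Prints: expected 2.5, after 5 tuning runs 3.90625 -- hence the scratch C buffer.
}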
-1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/rocm/rocm_common.h" // avoid provider_api.h ODR violation -#include "core/framework/tunable.h" -#include "core/providers/rocm/rocm_execution_provider_info.h" -#include "core/providers/rocm/tunable/rocm_tuning_context.h" -#include "core/providers/rocm/tunable/util.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -using OpParams = OpParams; - -template -using Op = Op; - -class Timer; -template -using TunableOp = TunableOp; - -} // namespace tunable -} // namespace rocm - -// As a convenience for authoring TunableOp in contrib namespace -namespace contrib { -namespace rocm { -using onnxruntime::rocm::tunable::Op; -using onnxruntime::rocm::tunable::OpParams; -using onnxruntime::rocm::tunable::RocmTuningContext; -using onnxruntime::rocm::tunable::TunableOp; -} // namespace rocm -} // namespace contrib - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc deleted file mode 100644 index 88e5fde189ba2..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/tunable/rocm_tuning_context.h" - -#include "core/providers/shared_library/provider_api.h" -#include "core/framework/tuning_context.h" -#define TUNING_CONTEXT_IMPL -#include "core/framework/tuning_context_impl.h" -#undef TUNING_CONTEXT_IMPL -#include "core/providers/rocm/rocm_execution_provider.h" -#include "core/providers/rocm/rocm_stream_handle.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -static std::string GetHipVersion() { - int version; - HIP_CALL_THROW(hipRuntimeGetVersion(&version)); - return std::to_string(version); -} - -static Status ValidateHipVersion(const std::string& value) { - auto current = GetHipVersion(); - ORT_RETURN_IF(current != value, "HIP runtime version mismatch: tuning results produced with HIP ", value, - ", onnxruntime currently run with HIP ", current); - return Status::OK(); -} - -static std::string GetRocBlasVersion() { - char buf[64]; - ROCBLAS_CALL_THROW(rocblas_get_version_string(buf, 256)); - buf[63] = '\0'; - return buf; -} - -static Status ValidateRocBlasVersion(const std::string& value) { - auto current = GetRocBlasVersion(); - ORT_RETURN_IF(current != value, "rocblas runtime version mismatch: tuning results produced with rocblas ", value, - ", onnxruntime currently run with rocblas ", current); - return Status::OK(); -} - -std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { - std::ostringstream oss; -#ifdef USE_COMPOSABLE_KERNEL - oss << "USE_CK=" << 1 << "|"; -#ifdef USE_COMPOSABLE_KERNEL_CK_TILE - oss << "USE_CKTILE=" << 1 << "|"; -#endif -#else - oss << "USE_CK=" << 0 << "|"; -#endif - -#ifdef USE_ROCBLAS_EXTENSION_API - oss << "USE_ROCBLAS_EXTENSION_API=" << 1 << "|"; -#else - oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|"; -#endif - -#ifdef USE_HIPBLASLT - oss << "USE_HIPBLASLT=" << 1 << "|"; -#else - oss << "USE_HIPBLASLT=" << 0 << "|"; -#endif - return oss.str(); -} - -std::string RocmTuningResultsValidator::GetDeviceModel() const { - return ep_->GetDeviceProp().name; -} - -Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { - auto 
current = GetDeviceModel(); - ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, - ", onnxruntime currently run with device ", current); - return Status::OK(); -} - -RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} { - RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); - RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); - RegisterValidator( - "DEVICE_MODEL", - [this]() { return GetDeviceModel(); }, - [this](const std::string& value) { return ValidateDeviceModel(value); }); -} - -RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) - : ITuningContext(ep), info_(info), validator_(ep) {} - -void RocmTuningContext::EnableTunableOp() { - LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider"; - info_->enable = true; -} - -void RocmTuningContext::DisableTunableOp() { - LOGS_DEFAULT(INFO) << "Disable TunableOp for ROCm Execution Provider"; - info_->enable = false; -} - -bool RocmTuningContext::IsTunableOpEnabled() const { - return info_->enable; -} - -void RocmTuningContext::EnableTuning() { - LOGS_DEFAULT(INFO) << "Enable TunableOp tuning for ROCm Execution Provider"; - info_->tuning_enable = true; -} - -void RocmTuningContext::DisableTuning() { - LOGS_DEFAULT(INFO) << "Disable TunableOp tuning for ROCm Execution Provider"; - info_->tuning_enable = false; -} - -bool RocmTuningContext::IsTuningEnabled() const { - return info_->tuning_enable; -} - -void RocmTuningContext::SetMaxTuningDurationMs(int max_duration_ms) { - info_->max_tuning_duration_ms = max_duration_ms; -} - -int RocmTuningContext::GetMaxTuningDurationMs() const { - return info_->max_tuning_duration_ms > 0 ? info_->max_tuning_duration_ms : std::numeric_limits::max(); -} - -TuningResultsManager& RocmTuningContext::GetTuningResultsManager() { - return manager_; -} - -const TuningResultsManager& RocmTuningContext::GetTuningResultsManager() const { - return manager_; -} - -const TuningResultsValidator& RocmTuningContext::GetTuningResultsValidator() const { - return validator_; -} - -IAllocatorUniquePtr RocmTuningContext::GetScratchBuffer( - size_t num_bytes, Stream* stream, OrtMemType mem_type) const { - if (num_bytes == 0) { - return nullptr; - } - - auto it = allocators_->find(ep_->GetOrtDeviceByMemType(mem_type)); - if (it == allocators_->end()) { - return nullptr; - } - - return IAllocator::MakeUniquePtr(it->second, num_bytes, false, stream, WaitRocmNotificationOnDevice); -} - -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h deleted file mode 100644 index ebad6b3ffc55b..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include - -#include "core/framework/tuning_context.h" -#include "core/providers/rocm/rocm_execution_provider_info.h" - -namespace onnxruntime { - -class ROCMExecutionProvider; - -namespace rocm { -namespace tunable { - -class RocmTuningResultsValidator : public TuningResultsValidator { - public: - RocmTuningResultsValidator(ROCMExecutionProvider* ep); - - protected: - std::string GetOrtBuildConfig() const override; - - std::string GetDeviceModel() const; - Status ValidateDeviceModel(const std::string& value) const; - - private: - ROCMExecutionProvider* ep_; // non-owning handle -}; - -class RocmTuningContext : public ITuningContext { - public: - explicit RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info); - - void EnableTunableOp() override; - void DisableTunableOp() override; - bool IsTunableOpEnabled() const override; - - void EnableTuning() override; - void DisableTuning() override; - bool IsTuningEnabled() const override; - - void SetMaxTuningDurationMs(int max_duration_ms) override; - int GetMaxTuningDurationMs() const override; - - TuningResultsManager& GetTuningResultsManager() override; - const TuningResultsManager& GetTuningResultsManager() const override; - - const TuningResultsValidator& GetTuningResultsValidator() const override; - - IAllocatorUniquePtr GetScratchBuffer( - size_t bytes, Stream* stream, OrtMemType mem_type = OrtMemTypeDefault) const; - - private: - TunableOpInfo* info_; // non-owning handle - TuningResultsManager manager_; - RocmTuningResultsValidator validator_; -}; - -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/util.cc b/onnxruntime/core/providers/rocm/tunable/util.cc deleted file mode 100644 index 6ee046eb7fef9..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/util.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/tunable/util.h" - -#include "core/providers/rocm/shared_inc/rocm_call.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -Timer::Timer(hipStream_t stream) : TimerBase(stream) { - HIP_CALL_THROW(hipEventCreate(&start_)); - HIP_CALL_THROW(hipEventCreate(&end_)); -} - -void Timer::Start() { - HIP_CALL_THROW(hipDeviceSynchronize()); - HIP_CALL_THROW(hipEventRecord(start_, stream_)); -} - -void Timer::End() { - HIP_CALL_THROW(hipEventRecord(end_, stream_)); - HIP_CALL_THROW(hipEventSynchronize(end_)); -} - -float Timer::Duration() { - float time; - // time is in ms with a resolution of 1 us - HIP_CALL_THROW(hipEventElapsedTime(&time, start_, end_)); - return time; -} - -Timer::~Timer() { - HIP_CALL_THROW(hipEventDestroy(start_)); - HIP_CALL_THROW(hipEventDestroy(end_)); -} - -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/util.h b/onnxruntime/core/providers/rocm/tunable/util.h deleted file mode 100644 index 36e1f52ce273b..0000000000000 --- a/onnxruntime/core/providers/rocm/tunable/util.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include - -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -class Timer : public ITimer { - public: - using TimerBase = ITimer; - - explicit Timer(hipStream_t stream); - - void Start() override; - void End() override; - float Duration() override; - ~Timer(); - - private: - hipEvent_t start_; - hipEvent_t end_; -}; - -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/version_script.lds b/onnxruntime/core/providers/rocm/version_script.lds deleted file mode 100644 index c02a8e4bcf724..0000000000000 --- a/onnxruntime/core/providers/rocm/version_script.lds +++ /dev/null @@ -1,10 +0,0 @@ -#_init and _fini should be local -VERS_1.0 { - global: - GetProvider; - _binary_*; - - # Hide everything else. - local: - *; -}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 32d93c305273d..14782b5a52262 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2959,6 +2959,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif if ((fp16_enable_ || bf16_enable_) && layer_norm_fp32_fallback_) { for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { auto layer = trt_network->getLayer(idx); @@ -2972,6 +2976,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } } } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif int num_inputs = trt_network->getNbInputs(); int num_outputs = trt_network->getNbOutputs(); @@ -3146,6 +3153,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif // Set precision flags std::string trt_node_name_with_precision = fused_node.Name(); if (fp16_enable_) { @@ -3163,7 +3174,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView trt_node_name_with_precision += "_int8"; LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; } - +#if defined(_MSC_VER) +#pragma warning(pop) +#endif // Set DLA if (fp16_enable_ || int8_enable_) { if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 @@ -3779,7 +3792,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); } } - +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif // Set precision if (trt_state->int8_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); @@ -3793,7 +3809,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; } - +#if defined(_MSC_VER) +#pragma warning(pop) +#endif // Set DLA (DLA can only run with FP16 or INT8) if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc 
b/onnxruntime/core/providers/webgpu/tensor/slice.cc index eb55134a31608..39432db5113d1 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -96,7 +96,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { // READ INPUTS const Tensor* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); - int64_t input_rank = static_cast(input_shape.NumDimensions()); + auto input_rank = input_shape.NumDimensions(); auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan() : gsl::make_span(attr_starts_); auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan() : gsl::make_span(attr_ends_); @@ -137,7 +137,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan(); - // PROCESS INPUTS + // get final axes std::vector axes, axes_fixed; for (unsigned int i = 0; i < axes_raw.size(); i++) { int64_t val = axes_raw[i]; @@ -149,46 +149,51 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } std::vector starts; - for (unsigned int i = 0; i < starts_raw.size(); i++) { - int64_t val = starts_raw[i]; - if (val < 0) { - val += input_shape[axes[i]]; - } + std::vector ends; + std::vector signs; + std::vector steps; + std::vector output_dims; + output_dims.resize(input_rank, 0); - if (steps_raw[i] < 0) { - val = std::max(static_cast(0), std::min(val, static_cast(input_shape[axes[i]] - 1))); - } else { - val = std::max(static_cast(0), std::min(val, static_cast(input_shape[axes[i]]))); + // main loop over axes that will setup + // starts, ends, steps, signs and output_dims + for (unsigned int i = 0; i < starts_raw.size(); i++) { + int64_t start = starts_raw[i]; + int64_t end = ends_raw[i]; + int64_t step = steps_raw[i]; + int64_t dim_value = input_shape[axes[i]]; + if (start < 0) { + start += dim_value; } - starts.push_back(static_cast(val)); - } - - std::vector ends; - for (unsigned int i = 0; i < ends_raw.size(); i++) { - int64_t val = ends_raw[i]; - if (val < 0) { - val += input_shape[axes[i]]; + if (end == std::numeric_limits::max() || end == std::numeric_limits::max()) { + end = step < 0 ? -1 : dim_value; + } else if (end < 0) { + end += dim_value; } - if (steps_raw[i] < 0) { - val = std::max(static_cast(0), std::min(val, static_cast(input_shape[axes[i]] - 1))); + if (step < 0) { + // we are slicing in reverse + start = std::clamp(start, int64_t{0}, dim_value - 1); + end = std::clamp(end, int64_t{-1}, dim_value - 1); + // note that we are flipping start and end to switch to forward step + signs.push_back(-1); + steps.push_back(static_cast(-step)); + starts.push_back(static_cast((end < 0) ? 0 : end)); + ends.push_back(static_cast(start)); } else { - val = std::max(static_cast(0), std::min(val, static_cast(input_shape[axes[i]]))); + // we are slicing in forward direction + start = std::clamp(start, int64_t{0}, dim_value); + end = std::clamp(end, int64_t{0}, dim_value); + signs.push_back(1); + steps.push_back(static_cast(step)); + starts.push_back(static_cast(start)); + ends.push_back(static_cast(end)); } - ends.push_back(static_cast(val)); + auto temp = static_cast(ceil(1.0 * (end - start) / static_cast(step))); + output_dims[axes[i]] = (temp > 0 && dim_value != 0) ? 
temp : 0; } - // temporary steps vector to handle negative steps - std::vector steps_tmp; - for (unsigned int i = 0; i < steps_raw.size(); i++) { - if (steps_raw[i] >= std::numeric_limits::max()) { - steps_tmp.push_back(std::numeric_limits::max()); - } else { - steps_tmp.push_back(static_cast(steps_raw[i])); - } - } - - // Insert missing dimensions - if (static_cast(axes.size()) != input_rank) { + // insert missing dimensions + if (axes.size() != input_rank) { for (uint32_t i = 0; i < input_rank; i++) { int idx = -1; for (unsigned int j = 0; j < axes_fixed.size(); j++) { @@ -198,46 +203,24 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } } if (idx == -1) { + uint32_t dim_value = static_cast(input_shape[i]); axes.insert(axes.begin() + i, i); starts.insert(starts.begin() + i, 0); - ends.insert(ends.begin() + i, static_cast(input_shape[i])); - steps_tmp.insert(steps_tmp.begin() + i, 1); + ends.insert(ends.begin() + i, dim_value); + signs.insert(signs.begin() + i, 1); + steps.insert(steps.begin() + i, 1); + output_dims[i] = dim_value; } } } - // retain the sign of the steps - std::vector signs; - for (unsigned int i = 0; i < steps_tmp.size(); i++) { - signs.push_back(steps_tmp[i] < 0 ? -1 : (steps_tmp[i] > 0 ? 1 : 0)); - } - - // Convert negative steps to positive steps and reverse starts and ends - for (unsigned int i = 0; i < steps_tmp.size(); i++) { - if (steps_tmp[i] < 0) { - float numSteps = static_cast((static_cast(ends[i]) - static_cast(starts[i])) / static_cast(steps_tmp[i])); - float newEnd = static_cast(starts[i]); - float newStart = newEnd + numSteps * static_cast(steps_tmp[i]); - - starts[i] = static_cast(newStart); - ends[i] = static_cast(newEnd); - steps_tmp[i] = static_cast(-steps_tmp[i]); - } - } - - // final steps vector of type unsigned int - std::vector steps; - for (unsigned int i = 0; i < steps_tmp.size(); i++) { - steps.push_back(static_cast(steps_tmp[i])); - } - // Reorder inputs in order of axis std::vector signs_reordered; std::vector steps_reordered, starts_reordered, ends_reordered; - signs_reordered.resize(static_cast(input_rank), 0); - steps_reordered.resize(static_cast(input_rank), 1); - starts_reordered.resize(static_cast(input_rank), 0); - ends_reordered.resize(static_cast(input_rank), 0); + signs_reordered.resize(input_rank, 0); + steps_reordered.resize(input_rank, 1); + starts_reordered.resize(input_rank, 0); + ends_reordered.resize(input_rank, 0); for (unsigned int i = 0; i < input_rank; i++) { int32_t dim = axes[i]; signs_reordered[dim] = signs[i]; @@ -246,16 +229,6 @@ Status Slice::ComputeInternal(ComputeContext& context) const { ends_reordered[dim] = ends[i]; } - // calculate output dims - std::vector output_dims; - for (unsigned int i = 0; i < input_rank; i++) { - float tmp = ceil((static_cast(ends_reordered[i]) - static_cast(starts_reordered[i])) / static_cast(steps_reordered[i])); - if (tmp < 0) - output_dims.push_back(0); - else - output_dims.push_back(static_cast(tmp)); - } - TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); @@ -275,4 +248,4 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 46e0347f0c0fd..13c746a6b1d31 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ 
b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -855,6 +855,21 @@ std::unique_ptr WebGpuExecutionProvider::GetEx } #endif +std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const { + if (target_data_layout != DataLayout::NHWC) { + return std::nullopt; + } + + // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider + if (node_domain == kOnnxDomain && node_op_type == "Resize") { + return false; + } + + return std::nullopt; +} + WebGpuExecutionProvider::~WebGpuExecutionProvider() { WebGpuContextFactory::ReleaseContext(context_id_); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 15aec16210f16..2003f9b2ebcc6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -57,6 +57,10 @@ class WebGpuExecutionProvider : public IExecutionProvider { DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } + std::optional ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const override; + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } // WebGPU EP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to diff --git a/onnxruntime/core/session/abi_devices.h b/onnxruntime/core/session/abi_devices.h index 06041eb0086ac..8f9e8c20926fc 100644 --- a/onnxruntime/core/session/abi_devices.h +++ b/onnxruntime/core/session/abi_devices.h @@ -7,9 +7,13 @@ #include #include "core/common/hash_combine.h" +#include "core/framework/ortdevice.h" #include "core/session/abi_key_value_pairs.h" #include "core/session/onnxruntime_c_api.h" +// alias API type to internal type +struct OrtMemoryDevice : OrtDevice {}; + struct OrtHardwareDevice { OrtHardwareDeviceType type; uint32_t vendor_id; @@ -62,4 +66,6 @@ struct OrtEpDevice { OrtKeyValuePairs ep_options; OrtEpFactory* ep_factory; + const OrtMemoryInfo* device_memory_info{nullptr}; + const OrtMemoryInfo* host_accessible_memory_info{nullptr}; }; diff --git a/onnxruntime/core/session/abi_ep_types.cc b/onnxruntime/core/session/abi_ep_types.cc index 719f55b4e6b38..14764251898aa 100644 --- a/onnxruntime/core/session/abi_ep_types.cc +++ b/onnxruntime/core/session/abi_ep_types.cc @@ -10,7 +10,8 @@ #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" -onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span nodes) { +onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span nodes, + const OrtNodeFusionOptions* optional_fusion_options) { std::vector ep_nodes; ep_nodes.reserve(nodes.size()); @@ -20,7 +21,8 @@ onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span ep_nodes; ep_nodes.push_back(onnxruntime::EpNode::ToInternal(node)); node_groupings.emplace_back(NodeGroupingKind::kSingleAssignedNode, std::move(ep_nodes)); - return onnxruntime::Status::OK(); } diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h index b19a03a57a78a..eb68d79a24279 100644 --- a/onnxruntime/core/session/abi_ep_types.h +++ b/onnxruntime/core/session/abi_ep_types.h @@ -30,16 +30,19 @@ struct OrtEpGraphSupportInfo { // A grouping of supported nodes that should be handled in a single ComputeCapability. 
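As context for the new optional fusion-options argument introduced above (and for the matching public entry point EpGraphSupportInfo_AddNodesToFuse, updated later in this diff), here is a minimal sketch of how a plugin EP's capability callback might request fusion. This is illustrative only: the helper name RequestNodeFusion, and the assumption that the EP already holds a pointer to the OrtEpApi table and has the ONNX Runtime EP API headers on its include path, are not part of this change. Since the fields of OrtNodeFusionOptions are not shown here, nullptr is passed to keep the default fusion behavior.

#include <vector>

// Hypothetical helper called from a plugin EP's capability callback.
// `ep_api` is assumed to be the OrtEpApi table the EP obtained at creation time.
OrtStatus* RequestNodeFusion(const OrtEpApi* ep_api,
                             OrtEpGraphSupportInfo* support_info,
                             const std::vector<const OrtNode*>& nodes) {
  // nullptr selects the default fusion behavior; a non-null OrtNodeFusionOptions
  // would let the EP customize how the fused node is created.
  return ep_api->EpGraphSupportInfo_AddNodesToFuse(support_info, nodes.data(),
                                                   nodes.size(),
                                                   /*node_fusion_options*/ nullptr);
}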
struct NodeGrouping { - NodeGrouping(NodeGroupingKind kind, std::vector&& nodes) - : kind(kind), nodes(std::move(nodes)) {} + NodeGrouping(NodeGroupingKind kind, std::vector&& nodes, + const OrtNodeFusionOptions& fusion_options = {}) + : kind(kind), nodes(std::move(nodes)), fusion_options(fusion_options) {} NodeGroupingKind kind = NodeGroupingKind::kInvalidGrouping; std::vector nodes; + OrtNodeFusionOptions fusion_options = {}; }; explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {} - onnxruntime::Status AddNodesToFuse(gsl::span nodes); + onnxruntime::Status AddNodesToFuse(gsl::span nodes, + const OrtNodeFusionOptions* node_fusion_options = nullptr); onnxruntime::Status AddSingleNode(const OrtNode* node); const onnxruntime::EpGraph& ort_graph; diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index 5d1f84ba96cf2..9e38a0ef75ccc 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -3,7 +3,9 @@ #include "allocator_adapters.h" #include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" #include "core/session/abi_key_value_pairs.h" +#include "core/session/environment.h" #include "core/session/inference_session.h" #include "core/session/ort_env.h" #include "core/session/ort_apis.h" @@ -89,22 +91,30 @@ onnxruntime::AllocatorPtr OrtAllocatorImplWrappingIAllocator::GetWrappedIAllocat } IAllocatorImplWrappingOrtAllocator::IAllocatorImplWrappingOrtAllocator(OrtAllocator* ort_allocator) - : IAllocator(*ort_allocator->Info(ort_allocator)), ort_allocator_(ort_allocator) {} + : IAllocator(*ort_allocator->Info(ort_allocator)) { + ort_allocator_ = OrtAllocatorUniquePtr(ort_allocator, [](OrtAllocator*) { + // no-op + }); +} + +IAllocatorImplWrappingOrtAllocator::IAllocatorImplWrappingOrtAllocator(OrtAllocatorUniquePtr ort_allocator) + : IAllocator(*ort_allocator->Info(ort_allocator.get())), ort_allocator_(std::move(ort_allocator)) { +} void* IAllocatorImplWrappingOrtAllocator::Alloc(size_t size) { - return ort_allocator_->Alloc(ort_allocator_, size); + return ort_allocator_->Alloc(ort_allocator_.get(), size); } void* IAllocatorImplWrappingOrtAllocator::Reserve(size_t size) { if (ort_allocator_->version >= kOrtAllocatorReserveMinVersion && ort_allocator_->Reserve) { - return ort_allocator_->Reserve(ort_allocator_, size); + return ort_allocator_->Reserve(ort_allocator_.get(), size); } - return ort_allocator_->Alloc(ort_allocator_, size); + return ort_allocator_->Alloc(ort_allocator_.get(), size); } void IAllocatorImplWrappingOrtAllocator::Free(void* p) { - return ort_allocator_->Free(ort_allocator_, p); + return ort_allocator_->Free(ort_allocator_.get(), p); } void IAllocatorImplWrappingOrtAllocator::GetStats(AllocatorStats* stats) { @@ -112,7 +122,7 @@ void IAllocatorImplWrappingOrtAllocator::GetStats(AllocatorStats* stats) { if (ort_allocator_->version >= kOrtAllocatorStatsMinVersion && ort_allocator_->GetStats) { OrtKeyValuePairs* kvps = nullptr; - Ort::ThrowOnError(ort_allocator_->GetStats(ort_allocator_, &kvps)); + Ort::ThrowOnError(ort_allocator_->GetStats(ort_allocator_.get(), &kvps)); auto release_fn = [](OrtKeyValuePairs** kvp) { OrtApis::ReleaseKeyValuePairs(*kvp); @@ -161,11 +171,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateAllocator, const OrtSession* sess, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocator, _Inout_ OrtEnv* env, +ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocator, _Inout_ OrtEnv* 
ort_env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg) { using namespace onnxruntime; - if (!env) { + if (!ort_env) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null"); } @@ -173,7 +183,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocator, _Inout_ OrtEnv* env, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtMemoryInfo is null"); } - auto st = env->CreateAndRegisterAllocator(*mem_info, arena_cfg); + auto& env = ort_env->GetEnvironment(); + auto st = env.CreateAndRegisterAllocator(*mem_info, arena_cfg); if (!st.IsOK()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str()); @@ -181,10 +192,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocator, _Inout_ OrtEnv* env, return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::RegisterAllocator, _Inout_ OrtEnv* env, +ORT_API_STATUS_IMPL(OrtApis::RegisterAllocator, _Inout_ OrtEnv* ort_env, _In_ OrtAllocator* allocator) { using namespace onnxruntime; - if (!env) { + if (!ort_env) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null"); } @@ -200,10 +211,8 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterAllocator, _Inout_ OrtEnv* env, "allocators only."); } - std::shared_ptr i_alloc_ptr = - std::make_shared(allocator); - - auto st = env->RegisterAllocator(i_alloc_ptr); + auto& env = ort_env->GetEnvironment(); + auto st = env.RegisterAllocator(allocator); if (!st.IsOK()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str()); @@ -211,10 +220,10 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterAllocator, _Inout_ OrtEnv* env, return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::UnregisterAllocator, _Inout_ OrtEnv* env, +ORT_API_STATUS_IMPL(OrtApis::UnregisterAllocator, _Inout_ OrtEnv* ort_env, _In_ const OrtMemoryInfo* mem_info) { using namespace onnxruntime; - if (!env) { + if (!ort_env) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null"); } @@ -222,7 +231,8 @@ ORT_API_STATUS_IMPL(OrtApis::UnregisterAllocator, _Inout_ OrtEnv* env, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Provided OrtMemoryInfo is null"); } - auto st = env->UnregisterAllocator(*mem_info); + auto& env = ort_env->GetEnvironment(); + auto st = env.UnregisterAllocator(*mem_info); if (!st.IsOK()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str()); @@ -234,8 +244,11 @@ ORT_API(void, OrtApis::ReleaseAllocator, _Frees_ptr_opt_ OrtAllocator* allocator delete static_cast(allocator); } -ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, - _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { +ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* ort_env, _In_ const char* provider_type, + _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys) { using namespace onnxruntime; std::unordered_map options; for (size_t i = 0; i != num_keys; i++) { @@ -252,7 +265,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, options[provider_options_keys[i]] = provider_options_values[i]; } - if (!env) { + if (!ort_env) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null"); } @@ -260,10 +273,64 @@ 
ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtMemoryInfo is null"); } - auto st = env->CreateAndRegisterAllocatorV2(provider_type, *mem_info, options, arena_cfg); + auto& env = ort_env->GetEnvironment(); + auto st = env.CreateAndRegisterAllocatorV2(provider_type, *mem_info, options, arena_cfg); + return onnxruntime::ToOrtStatus(st); +} - if (!st.IsOK()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str()); +ORT_API_STATUS_IMPL(OrtApis::GetSharedAllocator, _In_ OrtEnv* ort_env, _In_ const OrtMemoryInfo* mem_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) { + *allocator = nullptr; + + if (ort_env == nullptr || mem_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtEnv and OrtMemoryInfo must be provided"); + } + + auto& env = ort_env->GetEnvironment(); + auto st = env.GetSharedAllocator(*mem_info, *allocator); + return onnxruntime::ToOrtStatus(st); +} + +ORT_API_STATUS_IMPL(OrtApis::CreateSharedAllocator, + [[maybe_unused]] _In_ OrtEnv* ort_env, + [[maybe_unused]] _In_ const OrtEpDevice* ep_device, + [[maybe_unused]] _In_ OrtDeviceMemoryType mem_type, + [[maybe_unused]] _In_ OrtAllocatorType allocator_type, + [[maybe_unused]] _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_opt_ OrtAllocator** allocator) { +#if !defined(ORT_MINIMAL_BUILD) + + if (ort_env == nullptr || ep_device == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtEnv and OrtEpDevice must be provided"); } + + auto& env = ort_env->GetEnvironment(); + ORT_API_RETURN_IF_STATUS_NOT_OK(env.CreateSharedAllocator(*ep_device, mem_type, allocator_type, allocator_options, + allocator)); + return nullptr; +#else + // there's no support for plugin EPs in a minimal build so you can't get an OrtEpDevice + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); +#endif +} + +ORT_API_STATUS_IMPL(OrtApis::ReleaseSharedAllocator, + [[maybe_unused]] _In_ OrtEnv* ort_env, + [[maybe_unused]] _In_ const OrtEpDevice* ep_device, + [[maybe_unused]] _In_ OrtDeviceMemoryType mem_type) { +#if !defined(ORT_MINIMAL_BUILD) + if (ort_env == nullptr || ep_device == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtEnv and OrtEpDevice must be provided"); + } + + auto& env = ort_env->GetEnvironment(); + ORT_API_RETURN_IF_STATUS_NOT_OK(env.ReleaseSharedAllocator(*ep_device, mem_type)); + + return nullptr; +#else + // there's no support for plugin EPs in a minimal build so you can't get an OrtEpDevice + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); +#endif } diff --git a/onnxruntime/core/session/allocator_adapters.h b/onnxruntime/core/session/allocator_adapters.h index 8a180db75ec67..544c7828e46f8 100644 --- a/onnxruntime/core/session/allocator_adapters.h +++ b/onnxruntime/core/session/allocator_adapters.h @@ -25,11 +25,10 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl { ~OrtAllocatorImplWrappingIAllocator() override = default; void* Alloc(size_t size); - void Free(void* p); + void* Reserve(size_t size); const OrtMemoryInfo* Info() const; - void* Reserve(size_t size); std::unordered_map Stats() const; @@ -41,23 +40,32 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl { onnxruntime::AllocatorPtr i_allocator_; }; +using OrtAllocatorUniquePtr = std::unique_ptr>; + class 
IAllocatorImplWrappingOrtAllocator final : public IAllocator { public: + // ctor for OrtAllocator we do not own explicit IAllocatorImplWrappingOrtAllocator(OrtAllocator* ort_allocator); - ~IAllocatorImplWrappingOrtAllocator() override = default; + // ctor for OrtAllocator we own. + explicit IAllocatorImplWrappingOrtAllocator(OrtAllocatorUniquePtr ort_allocator); + + // ~IAllocatorImplWrappingOrtAllocator() override = default; void* Alloc(size_t size) override; + void Free(void* p) override; void* Reserve(size_t size) override; - void Free(void* p) override; + const OrtAllocator* GetWrappedOrtAllocator() const { + return ort_allocator_.get(); + } void GetStats(AllocatorStats* stats) override; ORT_DISALLOW_COPY_AND_ASSIGNMENT(IAllocatorImplWrappingOrtAllocator); private: - OrtAllocator* ort_allocator_ = nullptr; + OrtAllocatorUniquePtr ort_allocator_ = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 20b7410045333..b3176b399756e 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -6,6 +6,7 @@ #include #include "core/common/basic_types.h" +#include "core/framework/allocator.h" #include "core/framework/allocator_utils.h" #include "core/framework/error_code_helper.h" #include "core/graph/constants.h" @@ -69,39 +70,22 @@ std::once_flag schemaRegistrationOnceFlag; ProviderInfo_CUDA& GetProviderInfo_CUDA(); #endif // defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) -Status Environment::Create(std::unique_ptr logging_manager, - std::unique_ptr& environment, - const OrtThreadingOptions* tp_options, - bool create_global_thread_pools) { - environment = std::make_unique(); - auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pools); - return status; -} - -// Ugly but necessary for instances where we want to check equality of two OrtMemoryInfos -// without accounting for OrtAllocatorType in the equality checking process. -// TODO: Should we remove the OrtAllocatorType field from the OrtMemoryInfo struct to -// avoid such problems and also remove the unintuitive phenomenon of binding an allocator -// type to OrtMemoryInfo (which loosely is just device info) ? +namespace { +// Ignore whether there is an arena wrapping the allocator by excluding OrtMemoryInfo.alloc_type from the comparison static bool AreOrtMemoryInfosEquivalent( const OrtMemoryInfo& left, const OrtMemoryInfo& right, - bool include_allocator_type_for_equivalence_checking = true) { - if (include_allocator_type_for_equivalence_checking) { - return left == right; - } else { - return left.mem_type == right.mem_type && - left.device == right.device && - strcmp(left.name, right.name) == 0; - } + bool match_name = true) { + return left.mem_type == right.mem_type && + left.device == right.device && + (!match_name || strcmp(left.name, right.name) == 0); } -Status Environment::RegisterAllocator(AllocatorPtr allocator) { - const auto& mem_info = allocator->Info(); - - // We don't expect millions of allocators getting registered. Hence linear search should be fine. 
- auto ite = std::find_if(std::begin(shared_allocators_), - std::end(shared_allocators_), - [&mem_info](const AllocatorPtr& alloc_ptr) { +std::vector::const_iterator FindExistingAllocator(const std::vector& allocators, + const OrtMemoryInfo& mem_info, + bool match_name = true) { + auto ite = std::find_if(std::begin(allocators), + std::end(allocators), + [&mem_info, match_name](const AllocatorPtr& alloc_ptr) { // We want to do the equality checking of 2 OrtMemoryInfos sans the OrtAllocatorType field. // This is because we want to avoid registering two allocators for the same device that just // differ on OrtAllocatorType. @@ -111,14 +95,91 @@ Status Environment::RegisterAllocator(AllocatorPtr allocator) { // OrtDeviceAllocator (which is the only accepted value while registering a custom allocator). // If we allowed this, it could potentially cause a lot of confusion as to which shared allocator // to use for that device and we want to avoid having any ugly logic around this. - return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info, false); + return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info, match_name); }); - if (ite != shared_allocators_.end()) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "An allocator for this device has already been registered for sharing."); + return ite; +} + +std::unordered_set::const_iterator FindExistingAllocator(const std::unordered_set& allocators, + const OrtMemoryInfo& mem_info, + bool match_name = true) { + return std::find_if(std::begin(allocators), + std::end(allocators), + [&mem_info, match_name](const OrtAllocator* alloc_ptr) { + const auto* alloc_mem_info = alloc_ptr->Info(alloc_ptr); + return AreOrtMemoryInfosEquivalent(*alloc_mem_info, mem_info, match_name); + }); +} +} // namespace + +Status Environment::Create(std::unique_ptr logging_manager, + std::unique_ptr& environment, + const OrtThreadingOptions* tp_options, + bool create_global_thread_pools) { + environment = std::make_unique(); + auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pools); + return status; +} + +Status Environment::RegisterAllocator(OrtAllocator* allocator) { + std::lock_guard lock{mutex_}; + + auto allocator_ptr = std::make_shared(allocator); + + // for the public API we always want to replace any existing allocator for the device. 
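For reference, a hedged sketch of the replace-on-register behavior described in the comment above, seen from the application side. `first` and `second` are hypothetical custom OrtAllocator instances that report the same OrtMemoryInfo (and use OrtDeviceAllocator as their allocator type, as the public API still requires); error handling is reduced to releasing the status. With this change the second call replaces the previously registered allocator instead of failing with an "already registered" error.

#include "onnxruntime_c_api.h"

// Registering twice for the same device: the second registration now wins.
void ReplaceSharedAllocator(const OrtApi* api, OrtEnv* env,
                            OrtAllocator* first, OrtAllocator* second) {
  OrtStatus* status = api->RegisterAllocator(env, first);
  if (status != nullptr) {
    api->ReleaseStatus(status);
    return;
  }

  status = api->RegisterAllocator(env, second);  // replaces `first` for this device
  if (status != nullptr) {
    api->ReleaseStatus(status);
  }
}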
+ auto status = RegisterAllocatorImpl(allocator_ptr); + + // update shared_ort_allocators_ + if (status.IsOK()) { + if (auto it = FindExistingAllocator(shared_ort_allocators_, *allocator->Info(allocator), /*match_name*/ true); + it != shared_ort_allocators_.end()) { + shared_ort_allocators_.erase(it); + } + + shared_ort_allocators_.insert(allocator); + } + + return status; +} + +Status Environment::RegisterAllocatorImpl(AllocatorPtr allocator) { + const auto& mem_info = allocator->Info(); + + const bool match_name = false; + if (FindExistingAllocator(shared_allocators_, mem_info, match_name) != shared_allocators_.end()) { + ORT_RETURN_IF_ERROR(UnregisterAllocatorImpl(mem_info, match_name)); + } + + shared_allocators_.push_back(std::move(allocator)); + + return Status::OK(); +} +Status Environment::UnregisterAllocator(const OrtMemoryInfo& mem_info) { + std::lock_guard lock{mutex_}; + + return UnregisterAllocatorImpl(mem_info); +} + +Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool error_if_not_found) { + auto it = FindExistingAllocator(shared_allocators_, mem_info); + + if (error_if_not_found && it == shared_allocators_.end()) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No allocator for this device has been registered for sharing."); + } + + // we need to remove from shared_ort_allocators_ first in case the entry in shared_allocators_ owns the pointer in + // shared_ort_allocators_ + // e.g. a plug-in EP allocator is an IAllocatorImplWrappingOrtAllocator that owns the OrtAllocator* created by the EP + // so when we remove that from shared_allocators_ we release the OrtAllocator instance. + + // shared_ort_allocators_ are internal only so never an error if there's no match + auto it2 = FindExistingAllocator(shared_ort_allocators_, mem_info); + if (it2 != shared_ort_allocators_.end()) { + shared_ort_allocators_.erase(it2); } - shared_allocators_.insert(ite, allocator); + shared_allocators_.erase(it); return Status::OK(); } @@ -126,7 +187,8 @@ Status Environment::RegisterAllocator(AllocatorPtr allocator) { Status Environment::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg) { // TODO should we allow sharing of non-CPU allocators? if (mem_info.device.Type() != OrtDevice::CPU) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only CPU devices are supported. Please call CreateAndRegisterAllocatorV2() for other device."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Only CPU devices are supported. Please call CreateAndRegisterAllocatorV2() for other device."); } // determine if arena should be used @@ -148,15 +210,7 @@ Status Environment::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, co // override with values from the user supplied arena_cfg object if (arena_cfg) { max_mem = arena_cfg->max_mem; - arena_extend_strategy = arena_cfg->arena_extend_strategy; - // validate the value here - if (!(arena_extend_strategy == -1 || arena_extend_strategy == 0 || arena_extend_strategy == 1)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Received invalid value for arena extend strategy." 
- " Valid values can be either 0, 1 or -1."); - } - initial_chunk_size_bytes = arena_cfg->initial_chunk_size_bytes; max_dead_bytes_per_chunk = arena_cfg->max_dead_bytes_per_chunk; initial_growth_chunk_size_bytes = arena_cfg->initial_growth_chunk_size_bytes; @@ -172,31 +226,13 @@ Status Environment::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, co l_arena_cfg}; allocator_ptr = CreateAllocator(alloc_creation_info); } else { - AllocatorCreationInfo alloc_creation_info{[](int) { return std::make_unique(); }, + AllocatorCreationInfo alloc_creation_info{[mem_info](int) { return std::make_unique(mem_info); }, 0, create_arena}; allocator_ptr = CreateAllocator(alloc_creation_info); } - return RegisterAllocator(allocator_ptr); -} - -Status Environment::UnregisterAllocator(const OrtMemoryInfo& mem_info) { - auto ite = std::find_if(std::begin(shared_allocators_), - std::end(shared_allocators_), - [&mem_info](const AllocatorPtr& alloc_ptr) { - // See comment in RegisterAllocator() as to why we - // use this method of OrtMemoryInfo equality checking - return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info, false); - }); - - if (ite == shared_allocators_.end()) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, - "No allocator for this device has been registered for sharing."); - } - - shared_allocators_.erase(ite); - - return Status::OK(); + std::lock_guard lock{mutex_}; + return RegisterAllocatorImpl(allocator_ptr); } Status Environment::Initialize(std::unique_ptr logging_manager, @@ -319,8 +355,6 @@ Internal copy node #if !defined(ORT_MINIMAL_BUILD) // register internal EPs for autoep selection - // TODO: ??? Is there any reason not to do this like an EP allocates a large chunk of memory when created? - // If that is the case the user could register by name with no library path to do registration manually. 
ORT_RETURN_IF_ERROR(CreateAndRegisterInternalEps()); #endif } @@ -363,15 +397,20 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (provider_type == onnxruntime::kCudaExecutionProvider) { - CUDAExecutionProviderInfo cuda_ep_info; - GetProviderInfo_CUDA().CUDAExecutionProviderInfo__FromProviderOptions(options, cuda_ep_info); - CUDAExecutionProviderExternalAllocatorInfo external_info = cuda_ep_info.external_allocator_info; - AllocatorPtr allocator_ptr = GetProviderInfo_CUDA().CreateCudaAllocator( - static_cast(mem_info.device.Id()), - arena_cfg->max_mem, - static_cast(arena_cfg->arena_extend_strategy), - external_info, arena_cfg); - return RegisterAllocator(allocator_ptr); + if (mem_info.device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { + AllocatorPtr allocator_ptr = GetProviderInfo_CUDA().CreateCUDAPinnedAllocator(onnxruntime::CUDA_PINNED); + return RegisterAllocatorImpl(allocator_ptr); + } else { + CUDAExecutionProviderInfo cuda_ep_info; + GetProviderInfo_CUDA().CUDAExecutionProviderInfo__FromProviderOptions(options, cuda_ep_info); + CUDAExecutionProviderExternalAllocatorInfo external_info = cuda_ep_info.external_allocator_info; + AllocatorPtr allocator_ptr = GetProviderInfo_CUDA().CreateCudaAllocator( + static_cast(mem_info.device.Id()), + arena_cfg->max_mem, + static_cast(arena_cfg->arena_extend_strategy), + external_info, arena_cfg); + return RegisterAllocatorImpl(allocator_ptr); + } } #endif @@ -381,6 +420,27 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ Environment::~Environment() = default; +Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator) { + std::lock_guard lock{mutex_}; + + // doesn't matter whether we match a custom allocator or an EP allocator so match_name is false + auto it = FindExistingAllocator(shared_ort_allocators_, mem_info, /*match_name*/ false); + allocator = it != shared_ort_allocators_.end() ? *it : nullptr; + + // use the default CPU allocator if there's no custom or EP provided CPU allocator + if (!allocator && (mem_info.device.Type() == OrtDevice::CPU && + mem_info.device.MemType() == OrtDevice::MemType::DEFAULT)) { + if (!default_cpu_ort_allocator_) { + auto cpu_ort_allocator = std::make_unique(CPUAllocator::DefaultInstance()); + default_cpu_ort_allocator_ = std::move(cpu_ort_allocator); + } + + allocator = default_cpu_ort_allocator_.get(); + } + + return Status::OK(); +} + #if !defined(ORT_MINIMAL_BUILD) Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, std::unique_ptr ep_library, @@ -400,6 +460,19 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra execution_devices_.reserve(execution_devices_.size() + ep_info->execution_devices.size()); for (const auto& ed : ep_info->execution_devices) { execution_devices_.push_back(ed.get()); + + // add shared allocators so they're available without an inference session being required. + // we don't replace an existing allocator as we just need one to exist for the OrtMemoryInfo and we don't want + // to blow away any custom allocators previously added by the user. 
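To illustrate the point of the comment above (device allocators become usable as soon as the EP library is registered, with no inference session required), here is a minimal sketch. It assumes the application already has an OrtMemoryInfo matching the EP device's memory, and that GetSharedAllocator is exposed through the OrtApi table in the same way as the other allocator APIs added in this change; the helper name is made up.

#include "onnxruntime_c_api.h"

// Returns a buffer allocated by the shared allocator the environment registered
// when the EP library was loaded, or nullptr if no allocator matches mem_info.
void* AllocFromSharedAllocator(const OrtApi* api, OrtEnv* env,
                               const OrtMemoryInfo* mem_info, size_t bytes) {
  OrtAllocator* allocator = nullptr;
  OrtStatus* status = api->GetSharedAllocator(env, mem_info, &allocator);
  if (status != nullptr) {
    api->ReleaseStatus(status);
    return nullptr;
  }
  if (allocator == nullptr) {
    return nullptr;  // no shared allocator registered for this memory info
  }
  return allocator->Alloc(allocator, bytes);
}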
+ if (ed->device_memory_info != nullptr) { + ORT_RETURN_IF_ERROR(CreateSharedAllocatorImpl(*ed, *ed->device_memory_info, OrtDeviceAllocator, nullptr, + nullptr, /*replace_existing*/ false)); + } + + if (ed->host_accessible_memory_info != nullptr) { + ORT_RETURN_IF_ERROR(CreateSharedAllocatorImpl(*ed, *ed->host_accessible_memory_info, OrtDeviceAllocator, + nullptr, nullptr, /*replace_existing*/ false)); + } } for (const auto& internal_factory : internal_factories) { @@ -432,6 +505,8 @@ Status Environment::CreateAndRegisterInternalEps() { } Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path) { + std::lock_guard lock{mutex_}; + std::vector internal_factories = {}; std::unique_ptr ep_library; @@ -443,6 +518,8 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra } Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_name) { + std::lock_guard lock{mutex_}; + if (ep_libraries_.count(ep_name) == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution provider library: ", ep_name, " was not registered."); } @@ -450,9 +527,7 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam auto status = Status::OK(); ORT_TRY { - // unload. auto ep_info = std::move(ep_libraries_[ep_name]); - // remove from map and global list of OrtEpDevice* before unloading so we don't get a leftover entry if // something goes wrong in any of the following steps.. ep_libraries_.erase(ep_name); @@ -462,23 +537,135 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam } for (const auto& ed : ep_info->execution_devices) { + // remove from global list of OrtEpDevices if (auto it = std::find(execution_devices_.begin(), execution_devices_.end(), ed.get()); it != execution_devices_.end()) { execution_devices_.erase(it); } + + // unregister any shared allocators. + // match only the OrtEpDevice allocator in case the user registered a custom allocator with matching info. + const bool error_if_not_found = false; + if (ed->device_memory_info != nullptr) { + ORT_RETURN_IF_ERROR(UnregisterAllocatorImpl(*ed->device_memory_info, error_if_not_found)); + } + + if (ed->host_accessible_memory_info != nullptr) { + ORT_RETURN_IF_ERROR(UnregisterAllocatorImpl(*ed->host_accessible_memory_info, error_if_not_found)); + } } ep_info.reset(); } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to unregister EP library: ", ep_name, " with error: ", ex.what()); + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to unregister EP library: ", ep_name, " with error: ", + ex.what()); }); } return status; } +Status Environment::CreateSharedAllocator(const OrtEpDevice& ep_device, + OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator_out) { + auto* memory_info = mem_type == OrtDeviceMemoryType_DEFAULT ? 
ep_device.device_memory_info + : ep_device.host_accessible_memory_info; + if (memory_info == nullptr) { + return Status(ONNXRUNTIME, ORT_INVALID_ARGUMENT, "Invalid memory type for OrtEpDevice."); + } + + std::lock_guard lock{mutex_}; + return CreateSharedAllocatorImpl(ep_device, *memory_info, allocator_type, allocator_options, allocator_out, + /*replace_existing*/ true); +} + +Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, + const OrtMemoryInfo& memory_info, + OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator_out, + bool replace_existing) { + // if we're replacing an existing allocator we don't care who added it + if (auto it = FindExistingAllocator(shared_allocators_, memory_info, /*match_name*/ false); + it != shared_allocators_.end()) { + if (!replace_existing) { + return Status::OK(); + } + + shared_allocators_.erase(it); + } + + // clear out any exact match in the internal shared allocators + if (auto it = FindExistingAllocator(shared_ort_allocators_, memory_info, /*match_name*/ true); + it != shared_ort_allocators_.end()) { + shared_ort_allocators_.erase(it); + } + + OrtAllocator* allocator = nullptr; + auto* ort_status = ep_device.ep_factory->CreateAllocator(ep_device.ep_factory, &memory_info, allocator_options, + &allocator); + if (ort_status != nullptr) { + return ToStatusAndRelease(ort_status); + } + + if (allocator_out != nullptr) { + *allocator_out = allocator; + } + + auto ort_allocator = OrtAllocatorUniquePtr(allocator, + [&ep_device](OrtAllocator* allocator) { + ep_device.ep_factory->ReleaseAllocator(ep_device.ep_factory, allocator); + }); + + AllocatorPtr shared_allocator; + + if (allocator_type == OrtArenaAllocator) { + // wrap with arena + OrtArenaCfg arena_cfg; + if (allocator_options != nullptr) { + auto status = OrtArenaCfg::FromKeyValuePairs(*allocator_options, arena_cfg); + } + + // pending Stream support being added to plugin EP API in separate PR + // ep_device.ep_factory->IsStreamAware(ep_device.ep_factory); + bool stream_aware_arena = false; + + AllocatorCreationInfo alloc_creation_info{ + [&ort_allocator](int) -> std::unique_ptr { + return std::make_unique(std::move(ort_allocator)); + }, + /*unused*/ -1, // arg to the lambda above that is ignored as the device id comes from the allocator + /*create_arena*/ true, + arena_cfg, + stream_aware_arena, + }; + + shared_allocator = CreateAllocator(alloc_creation_info); + } else { + shared_allocator = std::make_shared(std::move(ort_allocator)); + } + + shared_ort_allocators_.insert(allocator); + shared_allocators_.push_back(std::move(shared_allocator)); + + return Status::OK(); +} + +Status Environment::ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type) { + auto* memory_info = mem_type == OrtDeviceMemoryType_DEFAULT ? 
ep_device.device_memory_info + : ep_device.host_accessible_memory_info; + if (memory_info == nullptr) { + return Status(ONNXRUNTIME, ORT_INVALID_ARGUMENT, "Invalid memory type for OrtEpDevice."); + } + + auto status = UnregisterAllocator(*memory_info); + + return status; +} + namespace { std::vector SortDevicesByType() { auto& devices = DeviceDiscovery::GetDevices(); diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index ffb5a286730ba..bbadfbee70656 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -7,6 +7,10 @@ #include #include "core/framework/error_code_helper.h" #include "core/framework/func_api.h" +#include "core/framework/ort_value.h" +#include "core/framework/ortdevice.h" +#include "core/framework/ortmemoryinfo.h" +#include "core/framework/tensor.h" #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" #include "core/session/abi_ep_types.h" @@ -44,7 +48,8 @@ ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device) { } ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* ort_graph_support_info, - _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes) { + _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes, + _In_opt_ const OrtNodeFusionOptions* node_fusion_options) { API_IMPL_BEGIN if (ort_graph_support_info == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance"); @@ -55,7 +60,7 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInf } gsl::span nodes_span(nodes, nodes + num_nodes); - ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span)); + ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span, node_fusion_options)); return nullptr; API_IMPL_END } @@ -85,6 +90,76 @@ ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeConte return compute_context->node_name; } +ORT_API_STATUS_IMPL(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device, + _In_ const OrtMemoryInfo* allocator_memory_info) { + const OrtDevice& info = allocator_memory_info->device; + switch (info.MemType()) { + case OrtDevice::MemType::DEFAULT: + ep_device->device_memory_info = allocator_memory_info; + break; + case OrtDevice::MemType::HOST_ACCESSIBLE: + ep_device->host_accessible_memory_info = allocator_memory_info; + break; + default: + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Memory type must be DEFAULT or HOST_ACCESSIBLE."); + } + + return nullptr; +} + +ORT_API(const OrtMemoryDevice*, MemoryInfo_GetMemoryDevice, _In_ const OrtMemoryInfo* memory_info) { + return static_cast(&memory_info->device); +} + +ORT_API_STATUS_IMPL(Value_GetMemoryDevice, _In_ const OrtValue* value, _Out_ const OrtMemoryDevice** device) { + *device = nullptr; + if (value == nullptr || value->IsTensor() == false) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue does not contain an allocated tensor."); + } + + auto& tensor = value->Get(); + *device = static_cast(&tensor.Location().device); + + return nullptr; +} + +ORT_API(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b) { + // don't care if they're both null as you don't need to call this function if they are + if (a == nullptr || b == nullptr) { + return false; + } + + // TODO: Validate this calls OrtDevice::operator== as expected + return *a == *b; +} + +ORT_API(OrtMemoryInfoDeviceType, 
MemoryDevice_GetDeviceType, _In_ const OrtMemoryDevice* memory_device) { + switch (memory_device->Type()) { + case OrtDevice::GPU: + return OrtMemoryInfoDeviceType_GPU; + case OrtDevice::NPU: + return OrtMemoryInfoDeviceType_NPU; + case OrtDevice::FPGA: + return OrtMemoryInfoDeviceType_FPGA; + case OrtDevice::CPU: + default: // should never happen. means we're out of sync with CreateMemoryInfo_V2 + return OrtMemoryInfoDeviceType_CPU; + } +} + +ORT_API(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device) { + return memory_device->MemType() == OrtDevice::MemType::DEFAULT ? OrtDeviceMemoryType_DEFAULT + : OrtDeviceMemoryType_HOST_ACCESSIBLE; +} + +ORT_API(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device) { + return memory_device->Vendor(); +} + +ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device) { + return memory_device->Id(); +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -96,6 +171,16 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::EpGraphSupportInfo_AddNodesToFuse, &OrtExecutionProviderApi::EpGraphSupportInfo_AddSingleNode, &OrtExecutionProviderApi::NodeComputeContext_NodeName, + &OrtExecutionProviderApi::EpDevice_AddAllocatorInfo, + + &OrtExecutionProviderApi::MemoryInfo_GetMemoryDevice, + &OrtExecutionProviderApi::Value_GetMemoryDevice, + + &OrtExecutionProviderApi::MemoryDevice_AreEqual, + &OrtExecutionProviderApi::MemoryDevice_GetDeviceType, + &OrtExecutionProviderApi::MemoryDevice_GetMemoryType, + &OrtExecutionProviderApi::MemoryDevice_GetVendorId, + &OrtExecutionProviderApi::MemoryDevice_GetDeviceId, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/ep_api.h index 84c8781a70adb..bc6b26dd30f34 100644 --- a/onnxruntime/core/session/ep_api.h +++ b/onnxruntime/core/session/ep_api.h @@ -18,9 +18,21 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device); ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes); + _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes, + _In_opt_ const OrtNodeFusionOptions* node_fusion_options); ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info, _In_ const OrtNode* node); ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context); +ORT_API_STATUS_IMPL(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device, + _In_ const OrtMemoryInfo* allocator_memory_info); + +ORT_API(const OrtMemoryDevice*, MemoryInfo_GetMemoryDevice, _In_ const OrtMemoryInfo* memory_info); +ORT_API_STATUS_IMPL(Value_GetMemoryDevice, _In_ const OrtValue* value, _Out_ const OrtMemoryDevice** device); + +ORT_API(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b); +ORT_API(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemoryDevice* memory_device); +ORT_API(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device); +ORT_API(uint32_t, 
MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device); +ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index 23c25b4e7befb..366f934fc610e 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -8,11 +8,11 @@ namespace onnxruntime { // used by EpFactoryInternal and EpFactoryProviderBridge. template struct ForwardToFactory { - static const char* ORT_API_CALL GetFactoryName(const OrtEpFactory* this_ptr) { + static const char* ORT_API_CALL GetFactoryName(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetName(); } - static const char* ORT_API_CALL GetVendor(const OrtEpFactory* this_ptr) { + static const char* ORT_API_CALL GetVendor(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetVendor(); } @@ -21,7 +21,7 @@ struct ForwardToFactory { size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* num_ep_devices) { + size_t* num_ep_devices) noexcept { return static_cast(this_ptr)->GetSupportedDevices(devices, num_devices, ep_devices, max_ep_devices, num_ep_devices); } @@ -32,12 +32,12 @@ struct ForwardToFactory { size_t num_devices, const OrtSessionOptions* session_options, const OrtLogger* logger, - OrtEp** ep) { + OrtEp** ep) noexcept { return static_cast(this_ptr)->CreateEp(devices, ep_metadata_pairs, num_devices, session_options, logger, ep); } - static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) { + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } }; diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index 354e609a6301c..b906f25935983 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -33,7 +33,7 @@ OrtStatus* EpFactoryInternal::GetSupportedDevices(const OrtHardwareDevice* const size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* num_ep_devices) { + size_t* num_ep_devices) noexcept { return get_supported_func_(this, devices, num_devices, ep_devices, max_ep_devices, num_ep_devices); } diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index 3853949e94375..1951b51a38bee 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -37,14 +37,14 @@ class EpFactoryInternal : public OrtEpFactory { GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); - const char* GetName() const { return ep_name_.c_str(); } - const char* GetVendor() const { return vendor_.c_str(); } + const char* GetName() const noexcept { return ep_name_.c_str(); } + const char* GetVendor() const noexcept { return vendor_.c_str(); } OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, _Inout_ OrtEpDevice** ep_devices, _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); + _Out_ size_t* num_ep_devices) noexcept; // we don't implement this. CreateIExecutionProvider should be used. 
OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index ebd74dd51774c..98e490a490c00 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -12,6 +12,7 @@ #include "core/framework/compute_capability.h" #include "core/framework/error_code_helper.h" #include "core/framework/model_metadef_id_generator.h" +#include "core/framework/plugin_data_transfer.h" #include "core/graph/ep_api_types.h" #include "core/graph/model_editor_api_types.h" #include "core/session/abi_devices.h" @@ -30,12 +31,13 @@ namespace onnxruntime { PluginExecutionProviderFactory::PluginExecutionProviderFactory(OrtEpFactory& ep_factory, gsl::span ep_devices) - : ep_factory_{ep_factory} { - devices_.reserve(ep_devices.size()); + : ep_factory_{ep_factory}, + devices_{ep_devices.begin(), ep_devices.end()} { + hardware_devices_.reserve(ep_devices.size()); ep_metadata_.reserve(ep_devices.size()); for (const auto* ep_device : ep_devices) { - devices_.push_back(ep_device->device); + hardware_devices_.push_back(ep_device->device); ep_metadata_.push_back(&ep_device->ep_metadata); } } @@ -44,15 +46,16 @@ std::unique_ptr PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_options, const OrtLogger& session_logger) { OrtEp* ort_ep = nullptr; - Status status = ToStatusAndRelease(ep_factory_.CreateEp(&ep_factory_, devices_.data(), ep_metadata_.data(), - devices_.size(), &session_options, &session_logger, &ort_ep)); + Status status = ToStatusAndRelease(ep_factory_.CreateEp(&ep_factory_, hardware_devices_.data(), ep_metadata_.data(), + hardware_devices_.size(), &session_options, &session_logger, + &ort_ep)); if (!status.IsOK()) { ORT_THROW("Error creating execution provider: ", status.ToString()); } auto ep_wrapper = std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), - session_options); + session_options, ep_factory_, devices_); ep_wrapper->SetLogger(session_logger.ToInternal()); return ep_wrapper; @@ -84,10 +87,24 @@ struct PluginEpMetaDefNameFunctor { // PluginExecutionProvider // -PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options) +PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, + OrtEpFactory& ep_factory, + gsl::span ep_devices) : IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins? 
- ort_ep_(std::move(ep)) { + ort_ep_(std::move(ep)), + ep_factory_(ep_factory), + ep_devices_(ep_devices.begin(), ep_devices.end()) { generate_ep_ctx_model_ = session_options.value.GetEpContextGenerationOptions().enable; + + for (const auto* ep_device : ep_devices_) { + if (ep_device->device_memory_info != nullptr) { + allocator_mem_infos_.push_back(ep_device->device_memory_info); + } + + if (ep_device->host_accessible_memory_info != nullptr) { + allocator_mem_infos_.push_back(ep_device->host_accessible_memory_info); + } + } } PluginExecutionProvider::~PluginExecutionProvider() { @@ -138,6 +155,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { std::unordered_set node_set; node_set.reserve(node_grouping.nodes.size()); + for (const EpNode* ep_node : node_grouping.nodes) { node_set.insert(&ep_node->GetInternalNode()); } @@ -151,7 +169,8 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // unsupported nodes in any path between supported nodes. std::vector> capabilities = utils::CreateSupportedPartitions( graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()), - this->Type(), this->Type(), /*node_unit_map*/ nullptr); + this->Type(), this->Type(), /*node_unit_map*/ nullptr, + node_grouping.fusion_options.drop_constant_initializers); if (capabilities.size() > 1) { LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. " @@ -271,8 +290,8 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) } -common::Status PluginExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_infos) { +Status PluginExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_infos) { const logging::Logger* logger = GetLogger(); const size_t num_graphs = fused_nodes_and_graphs.size(); std::vector> api_graphs_holder; @@ -390,4 +409,139 @@ const InlinedVector PluginExecutionProvider::GetEpContextNodes() co return result; } +namespace { + +struct DataLayoutMapping { + DataLayout data_layout; + OrtEpDataLayout api_data_layout; +}; + +// Maps enum values between `onnxruntime::DataLayout` and `OrtEpDataLayout`. 
+constexpr std::array kDataLayoutMappings{ + DataLayoutMapping{DataLayout::NCHW, OrtEpDataLayout::OrtEpDataLayout_NCHW}, + DataLayoutMapping{DataLayout::NHWC, OrtEpDataLayout::OrtEpDataLayout_NHWC}, +}; + +} // namespace + +DataLayout PluginExecutionProvider::GetPreferredLayout() const { + if (ort_ep_->GetPreferredDataLayout == nullptr) { + return Base::GetPreferredLayout(); + } + + OrtEpDataLayout api_data_layout{}; + + ORT_THROW_IF_ERROR(ToStatusAndRelease(ort_ep_->GetPreferredDataLayout(ort_ep_.get(), &api_data_layout))); + + const auto data_layout_mapping = std::find_if(kDataLayoutMappings.begin(), kDataLayoutMappings.end(), + [api_data_layout](const DataLayoutMapping& mapping) { + return mapping.api_data_layout == api_data_layout; + }); + + ORT_ENFORCE(data_layout_mapping != kDataLayoutMappings.end(), + "OrtEp::GetPreferredDataLayout() returned an invalid data layout: ", static_cast(api_data_layout)); + + return data_layout_mapping->data_layout; +} + +std::optional PluginExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const { + if (ort_ep_->ShouldConvertDataLayoutForOp == nullptr) { + return Base::ShouldConvertDataLayoutForOp(node_domain, node_op_type, target_data_layout); + } + + const auto data_layout_mapping = std::find_if(kDataLayoutMappings.begin(), kDataLayoutMappings.end(), + [target_data_layout](const DataLayoutMapping& mapping) { + return mapping.data_layout == target_data_layout; + }); + + ORT_ENFORCE(data_layout_mapping != kDataLayoutMappings.end(), + "Unable to map target_data_layout (", static_cast(target_data_layout), ") to OrtEpDataLayout."); + + // Ensure domain and op type strings are null-terminated. + const std::string node_domain_str{node_domain}, node_op_type_str{node_op_type}; + int should_convert = -1; + + ORT_THROW_IF_ERROR(ToStatusAndRelease( + ort_ep_->ShouldConvertDataLayoutForOp(ort_ep_.get(), + node_domain_str.c_str(), node_op_type_str.c_str(), + data_layout_mapping->api_data_layout, + &should_convert))); + + if (should_convert > 0) { + return true; + } else if (should_convert == 0) { + return false; + } else { + return std::nullopt; + } +} + +Status PluginExecutionProvider::OnRunStart(const RunOptions& run_options) { + if (ort_ep_->OnRunStart == nullptr) { + return Base::OnRunStart(run_options); + } + + return ToStatusAndRelease(ort_ep_->OnRunStart(ort_ep_.get(), &run_options)); +} + +Status PluginExecutionProvider::OnRunEnd(bool sync_stream, const RunOptions& run_options) { + if (ort_ep_->OnRunEnd == nullptr) { + return Base::OnRunEnd(sync_stream, run_options); + } + + return ToStatusAndRelease(ort_ep_->OnRunEnd(ort_ep_.get(), &run_options, sync_stream)); +} + +Status PluginExecutionProvider::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + if (ort_ep_->SetDynamicOptions == nullptr) { + return Base::SetEpDynamicOptions(keys, values); + } + + ORT_RETURN_IF_NOT(keys.size() == values.size(), + "The number of keys (", keys.size(), ") and number of values (", values.size(), + ") must be the same."); + + return ToStatusAndRelease(ort_ep_->SetDynamicOptions(ort_ep_.get(), keys.data(), values.data(), keys.size())); +} +std::unique_ptr PluginExecutionProvider::GetDataTransfer() const { + OrtDataTransferImpl* data_transfer_impl = nullptr; + OrtStatus* status = ep_factory_.CreateDataTransfer(&ep_factory_, &data_transfer_impl); + if (status != nullptr) { + ORT_THROW("Error creating data transfer: ", ToStatusAndRelease(status).ToString()); + } + + if 
(data_transfer_impl == nullptr) { + return {}; + } + + return std::make_unique(*data_transfer_impl); +} + +std::vector PluginExecutionProvider::CreatePreferredAllocators() { + std::vector allocators; + allocators.reserve(allocator_mem_infos_.size()); + + for (const auto* memory_info : allocator_mem_infos_) { + OrtAllocator* ort_allocator_ptr = nullptr; + OrtStatus* ort_status = ep_factory_.CreateAllocator(&ep_factory_, memory_info, nullptr, &ort_allocator_ptr); + + // throw or log? start with throw + if (ort_status != nullptr) { + ORT_THROW("Error creating allocator: ", ToStatusAndRelease(ort_status).ToString()); + } + + auto ort_allocator = OrtAllocatorUniquePtr( + ort_allocator_ptr, + [this](OrtAllocator* allocator) { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + }); + allocators.push_back(std::make_shared(std::move(ort_allocator))); + } + + return allocators; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/ep_plugin_provider_interfaces.h index 2b88c7f5d494f..343d6c9ad464e 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.h @@ -35,7 +35,8 @@ struct PluginExecutionProviderFactory : public IExecutionProviderFactory { private: OrtEpFactory& ep_factory_; - std::vector devices_; + std::vector devices_; + std::vector hardware_devices_; std::vector ep_metadata_; }; @@ -59,8 +60,12 @@ using UniqueOrtEp = std::unique_ptr; /// IExecutionProvider that wraps an instance of OrtEp. /// class PluginExecutionProvider : public IExecutionProvider { + private: + using Base = IExecutionProvider; + public: - explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options); + explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, + gsl::span ep_devices); ~PluginExecutionProvider(); std::vector> @@ -69,11 +74,31 @@ class PluginExecutionProvider : public IExecutionProvider { const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant = nullptr) const override; - common::Status Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; + Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) override; + + DataLayout GetPreferredLayout() const override; + + std::optional ShouldConvertDataLayoutForOp(std::string_view node_domain, + std::string_view node_op_type, + DataLayout target_data_layout) const override; + + Status OnRunStart(const RunOptions& run_options) override; + + Status OnRunEnd(bool sync_stream, const RunOptions& run_options) override; + + Status SetEpDynamicOptions(gsl::span keys, + gsl::span values) override; const InlinedVector GetEpContextNodes() const override; + std::unique_ptr GetDataTransfer() const override; + + // create per-session allocators + // longer term we should prefer shared allocators in Environment and only create per-session allocators as + // needed based on matching against allocator_mem_infos_. 
+ std::vector CreatePreferredAllocators() override; + private: struct FusedNodeState { FusedNodeState() = default; @@ -86,7 +111,11 @@ class PluginExecutionProvider : public IExecutionProvider { }; UniqueOrtEp ort_ep_; + OrtEpFactory& ep_factory_; + std::vector ep_devices_; + std::vector allocator_mem_infos_; bool generate_ep_ctx_model_ = false; + std::vector api_node_compute_infos_; // Fused nodes have to be valid throughout model inference because they may be cached in NodeComputeInfo instances. diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 468639a9f25bb..86a61a4d0ee74 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3127,6 +3127,9 @@ Status InferenceSession::Run(const RunOptions& run_options, LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info)); } + + // Log runtime error telemetry if the return value is not OK + ORT_RETURN_IF_ERROR_SESSIONID(retval, session_id_); return retval; } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 7670197455a9d..4e25187ff1b47 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -630,6 +630,10 @@ class InferenceSession { return weight_hash_; } + uint32_t GetCurrentSessionId() const { + return session_id_; + } + protected: #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 2a6dc31344f6b..15f86cf0d7002 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -31,6 +31,7 @@ #include "core/graph/constants.h" #include "core/graph/graph.h" #include "core/graph/model_editor_api_types.h" +#include "core/graph/ep_api_types.h" #include "core/providers/get_execution_providers.h" #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" @@ -114,12 +115,12 @@ using namespace onnxruntime; #define TENSOR_READ_API_BEGIN \ API_IMPL_BEGIN \ auto v = reinterpret_cast(value); \ - auto& tensor = v->Get(); + const auto& tensor = v->Get(); #define TENSOR_READWRITE_API_BEGIN \ API_IMPL_BEGIN \ auto v = (value); \ - auto tensor = v->GetMutable(); + auto* tensor = v->GetMutable(); namespace { // Create tensor. Allocates memory. Tensor owns memory. Allocator is wrapped and stored in a shared_ptr in Tensor. 
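Note on the allocator plumbing introduced above: `EpDevice_AddAllocatorInfo` (ep_api.cc) stores an `OrtMemoryInfo` on the `OrtEpDevice` keyed by its memory type, and `PluginExecutionProvider::CreatePreferredAllocators` later asks the factory's `CreateAllocator`/`ReleaseAllocator` pair to back each registered memory info. A minimal sketch of the factory side follows; it assumes the two `OrtMemoryInfo` objects were created by the factory elsewhere and that `ep_api` points at the `OrtEpApi` table (the accessor used to obtain it is not part of this diff), so treat those names as illustrative.

```cpp
// Sketch only: a plugin EP factory advertising allocator info on the
// OrtEpDevice instances it reports from GetSupportedDevices().
// Assumes the OrtEpApi struct member names mirror the implementation names above.
OrtStatus* AttachAllocatorInfo(const OrtEpApi* ep_api, OrtEpDevice* ep_device,
                               const OrtMemoryInfo* device_mem_info,
                               const OrtMemoryInfo* host_accessible_mem_info) {
  // DEFAULT memory type populates ep_device->device_memory_info.
  OrtStatus* status = ep_api->EpDevice_AddAllocatorInfo(ep_device, device_mem_info);
  if (status != nullptr) {
    return status;
  }

  // HOST_ACCESSIBLE memory type populates ep_device->host_accessible_memory_info.
  if (host_accessible_mem_info != nullptr) {
    status = ep_api->EpDevice_AddAllocatorInfo(ep_device, host_accessible_mem_info);
  }

  return status;
}
```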
@@ -785,7 +786,7 @@ ORT_API_STATUS_IMPL(OrtApis::SetEpDynamicOptions, _Inout_ OrtSession* sess, Status status; if (kv_len == 0) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "no imputs were passed"); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "no inputs were passed"); } else { status = session->SetEpDynamicOptions(keys_span, values_span); @@ -1090,6 +1091,13 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorMutableData, _Inout_ OrtValue* value, _Out API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::GetTensorData, _Inout_ const OrtValue* value, _Outptr_ const void** output) { + TENSOR_READ_API_BEGIN + *output = tensor.DataRaw(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len) { TENSOR_READWRITE_API_BEGIN auto* dst = tensor->MutableData(); @@ -2239,6 +2247,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfg, _In_ size_t max_mem, int arena_exte cfg->initial_chunk_size_bytes = initial_chunk_size_bytes; cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk; cfg->max_dead_bytes_per_chunk = -1L; + + if (!cfg->IsValid()) { + return CreateStatus(ORT_INVALID_ARGUMENT, "Invalid configuration value was provided."); + } + *out = cfg.release(); return nullptr; API_IMPL_END @@ -2270,6 +2283,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfgV2, _In_reads_(num_keys) const char* } } + if (!cfg->IsValid()) { + return CreateStatus(ORT_INVALID_ARGUMENT, "Invalid configuration value was provided."); + } + *out = cfg.release(); return nullptr; API_IMPL_END @@ -2842,6 +2859,96 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetImplicitInputs, _In_ const OrtNode* node, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes) { + API_IMPL_BEGIN + if (attributes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attributes' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetAttributes(array)); + + *attributes = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) { + API_IMPL_BEGIN + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL"); + } + + const EpNode* ep_node = EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); + } + + *attribute = ep_node->GetAttribute(attribute_name); + + if (*attribute) { + return nullptr; + } else { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); + } + + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) { + API_IMPL_BEGIN + const auto attr = attribute->attr_proto; + auto onnx_attr_type = attribute->attr_proto.type(); + switch (onnx_attr_type) { + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED: { + *type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT: { + *type = OrtOpAttrType::ORT_OP_ATTR_INT; + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS: { + *type = OrtOpAttrType::ORT_OP_ATTR_INTS; + break; + 
} + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT: { + *type = OrtOpAttrType::ORT_OP_ATTR_FLOAT; + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS: { + *type = OrtOpAttrType::ORT_OP_ATTR_FLOATS; + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING: { + *type = OrtOpAttrType::ORT_OP_ATTR_STRING; + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS: { + *type = OrtOpAttrType::ORT_OP_ATTR_STRINGS; + break; + } + default: + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name) { + API_IMPL_BEGIN + if (name == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name argument is null"); + } + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); + } + + *name = attribute->attr_proto.name().c_str(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs) { API_IMPL_BEGIN if (subgraphs == nullptr) { @@ -3069,6 +3176,10 @@ ORT_API(const OrtHardwareDevice*, OrtApis::EpDevice_Device, _In_ const OrtEpDevi return ep_device->device; } +ORT_API(const OrtMemoryInfo*, OrtApis::EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device) { + return ep_device->device_memory_info; +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -3501,6 +3612,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetTensorSizeInBytes, &OrtApis::AllocatorGetStats, + &OrtApis::CreateMemoryInfo_V2, &OrtApis::CreateArrayOfConstObjects, @@ -3537,9 +3649,22 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetInputs, &OrtApis::Node_GetOutputs, &OrtApis::Node_GetImplicitInputs, + &OrtApis::Node_GetAttributes, + &OrtApis::Node_GetAttributeByName, + &OrtApis::OpAttr_GetType, + &OrtApis::OpAttr_GetName, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetParentGraph, + &OrtApis::GetRunConfigEntry, + + &OrtApis::EpDevice_MemoryInfo, + + &OrtApis::CreateSharedAllocator, + &OrtApis::GetSharedAllocator, + &OrtApis::ReleaseSharedAllocator, + + &OrtApis::GetTensorData, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
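The API table above registers the new introspection entry points (`Node_GetAttributes`, `Node_GetAttributeByName`, `OpAttr_GetType`, `OpAttr_GetName`) plus `GetTensorData`, the read-only counterpart of `GetTensorMutableData`. A hedged usage sketch, assuming `ort` is the `OrtApi*` from `OrtGetApiBase()`, that the struct members follow the usual naming, and that the attribute name "axis" is purely illustrative:

```cpp
// Sketch only: query the type of a (hypothetical) "axis" attribute and read a
// tensor's data pointer through the new read-only accessor.
OrtStatus* InspectNode(const OrtApi* ort, const OrtNode* node, const OrtValue* tensor_value) {
  const OrtOpAttr* attr = nullptr;
  OrtStatus* status = ort->Node_GetAttributeByName(node, "axis", &attr);
  if (status != nullptr) {
    return status;  // per the implementation above, a missing attribute is reported as an error status
  }

  OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED;
  status = ort->OpAttr_GetType(attr, &attr_type);  // e.g. ORT_OP_ATTR_INT for an integer attribute
  if (status != nullptr) {
    return status;
  }

  const void* data = nullptr;
  return ort->GetTensorData(tensor_value, &data);  // const data pointer; no mutation allowed
}
```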
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 32319152d7e01..8e6734b914be2 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -662,8 +662,28 @@ ORT_API_STATUS_IMPL(Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* s ORT_API_STATUS_IMPL(Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs); ORT_API_STATUS_IMPL(Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs); ORT_API_STATUS_IMPL(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs); +ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attrs); +ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute); +ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); +ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs); ORT_API_STATUS_IMPL(Node_GetParentGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** parent_graph); +ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, + _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); + +ORT_API(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device); + +ORT_API_STATUS_IMPL(CreateSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType mem_type, _In_ OrtAllocatorType allocator_type, + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_opt_ OrtAllocator** allocator); +ORT_API_STATUS_IMPL(GetSharedAllocator, _In_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, + _Outptr_result_maybenull_ OrtAllocator** allocator); + +ORT_API_STATUS_IMPL(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType mem_type); + +ORT_API_STATUS_IMPL(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 57d97d1b862d6..1bd8a18d7255f 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -105,10 +105,17 @@ void OrtEnv::Release(OrtEnv* env_ptr) { instance_to_delete = p_instance_; // Point to the instance to be deleted. p_instance_ = nullptr; // Set the static instance pointer to nullptr under the lock. } else { +#if !defined(ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) // Process is shutting down, let it leak. // p_instance_ remains as is (though ref_count_ is 0), future CreateEnv calls // would increment ref_count_ on this "leaked" instance. // This behavior matches the requirement to "just let the memory leak out". +#else + // we're tracing for memory leaks so we want to avoid as many leaks as possible and the leaks are considered + // as failures for test apps. + instance_to_delete = p_instance_; + p_instance_ = nullptr; +#endif } } } // Mutex m_ is released here when lock_guard goes out of scope. 
@@ -125,22 +132,3 @@ onnxruntime::logging::LoggingManager* OrtEnv::GetLoggingManager() const { void OrtEnv::SetLoggingManager(std::unique_ptr logging_manager) { value_->SetLoggingManager(std::move(logging_manager)); } - -onnxruntime::common::Status OrtEnv::RegisterAllocator(AllocatorPtr allocator) { - auto status = value_->RegisterAllocator(allocator); - return status; -} - -onnxruntime::common::Status OrtEnv::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, - const OrtArenaCfg* arena_cfg) { - auto status = value_->CreateAndRegisterAllocator(mem_info, arena_cfg); - return status; -} - -onnxruntime::common::Status OrtEnv::UnregisterAllocator(const OrtMemoryInfo& mem_info) { - return value_->UnregisterAllocator(mem_info); -} - -onnxruntime::common::Status OrtEnv::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg) { - return value_->CreateAndRegisterAllocatorV2(provider_type, mem_info, options, arena_cfg); -} diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 20ac1b633e29d..94c8e0a6ea2e8 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -48,26 +48,8 @@ struct OrtEnv { onnxruntime::logging::LoggingManager* GetLoggingManager() const; void SetLoggingManager(std::unique_ptr logging_manager); - /** - * Registers an allocator for sharing between multiple sessions. - * Returns an error if an allocator with the same OrtMemoryInfo is already registered. - */ - onnxruntime::common::Status RegisterAllocator(onnxruntime::AllocatorPtr allocator); - - /** - * Creates and registers an allocator for sharing between multiple sessions. - * Return an error if an allocator with the same OrtMemoryInfo is already registered. - */ - onnxruntime::common::Status CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, - const OrtArenaCfg* arena_cfg = nullptr); - - /** - * Removes registered allocator that was previously registered for sharing between multiple sessions. - */ - onnxruntime::common::Status UnregisterAllocator(const OrtMemoryInfo& mem_info); OrtEnv(std::unique_ptr value); ~OrtEnv(); - onnxruntime::common::Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg = nullptr); private: // p_instance_ holds the single, global instance of OrtEnv. 
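With the `OrtEnv::RegisterAllocator`/`CreateAndRegisterAllocator`/`UnregisterAllocator` wrappers removed above, environment-level allocator sharing goes through the new C API entries declared in ort_apis.h (`CreateSharedAllocator`, `GetSharedAllocator`, `ReleaseSharedAllocator`) together with `EpDevice_MemoryInfo`. A minimal sketch under those signatures; the `OrtDeviceAllocator` choice and the null options are illustrative assumptions, and `env`/`ep_device` are assumed to come from the usual environment creation and EP-device discovery calls:

```cpp
// Sketch only: create, look up, and release a shared allocator for an EP device's
// default (device) memory, using the signatures declared in ort_apis.h above.
OrtStatus* ExerciseSharedAllocator(const OrtApi* ort, OrtEnv* env, const OrtEpDevice* ep_device) {
  OrtAllocator* allocator = nullptr;
  OrtStatus* status = ort->CreateSharedAllocator(env, ep_device, OrtDeviceMemoryType_DEFAULT,
                                                 OrtDeviceAllocator, /*allocator_options*/ nullptr,
                                                 &allocator);
  if (status != nullptr) {
    return status;
  }

  // Lookup is keyed by the OrtMemoryInfo the EP device advertises for its device memory.
  const OrtMemoryInfo* mem_info = ort->EpDevice_MemoryInfo(ep_device);
  OrtAllocator* found = nullptr;  // may remain null if nothing is registered for mem_info
  status = ort->GetSharedAllocator(env, mem_info, &found);
  if (status != nullptr) {
    return status;
  }

  return ort->ReleaseSharedAllocator(env, ep_device, OrtDeviceMemoryType_DEFAULT);
}
```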
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 422668ef1a27f..2a1f7580ac3aa 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -63,11 +63,6 @@ #include "orttraining/core/framework/distributed_run_context.h" #endif -#if defined(USE_ROCM) && defined(ORT_USE_NCCL) && defined(ENABLE_TRAINING) -#include "orttraining/training_ops/rocm/communication/nccl_service.h" -#include "orttraining/core/framework/distributed_run_context.h" -#endif - #ifdef _WIN32 #include "core/platform/windows/logging/etw_sink.h" #endif @@ -101,7 +96,6 @@ using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwIn #include "core/providers/cuda/cuda_provider_factory_creator.h" #include "core/providers/cann/cann_provider_factory_creator.h" -#include "core/providers/rocm/rocm_provider_factory_creator.h" #include "core/providers/dnnl/dnnl_provider_factory_creator.h" #include "core/providers/migraphx/migraphx_provider_factory_creator.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" @@ -112,7 +106,6 @@ using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwIn #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cann/cann_provider_factory.h" -#include "core/providers/rocm/rocm_provider_factory.h" #include "core/providers/dnnl/dnnl_provider_factory.h" #include "core/providers/migraphx/migraphx_provider_factory.h" #include "core/providers/openvino/openvino_provider_factory.h" @@ -156,8 +149,6 @@ ProviderInfo_CANN* TryGetProviderInfo_CANN(); ProviderInfo_CANN& GetProviderInfo_CANN(); ProviderInfo_Dnnl* TryGetProviderInfo_Dnnl(); ProviderInfo_Dnnl& GetProviderInfo_Dnnl(); -ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); -ProviderInfo_ROCM& GetProviderInfo_ROCM(); ProviderHostCPU& GetProviderHostCPU(); ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); @@ -306,22 +297,7 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGraphX().CreateMIGraphXPinnedAllocator(device_id, name); } #endif -#ifdef USE_ROCM - std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_ROCM().CreateROCMAllocator(device_id, name); } - std::unique_ptr CreateROCMPinnedAllocator(const char* name) override { return GetProviderInfo_ROCM().CreateROCMPinnedAllocator(name); } - std::unique_ptr CreateGPUDataTransfer() override { return GetProviderInfo_ROCM().CreateGPUDataTransfer(); } - - void rocm__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) override { return GetProviderInfo_ROCM().rocm__Impl_Cast(stream, input_data, output_data, count); } - void rocm__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) override { return GetProviderInfo_ROCM().rocm__Impl_Cast(stream, input_data, output_data, count); } - - void rocm__Impl_Cast(void* stream, const double* input_data, float* output_data, size_t count) override { return GetProviderInfo_ROCM().rocm__Impl_Cast(stream, input_data, output_data, count); } - void rocm__Impl_Cast(void* stream, const float* input_data, double* output_data, size_t count) override { return GetProviderInfo_ROCM().rocm__Impl_Cast(stream, input_data, output_data, count); } - - Status RocmCall_false(int retCode, const char* exprString, 
const char* libName, int successCode, const char* msg, const char* file, const int line) override { return GetProviderInfo_ROCM().RocmCall_false(retCode, exprString, libName, successCode, msg, file, line); } - void RocmCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { GetProviderInfo_ROCM().RocmCall_true(retCode, exprString, libName, successCode, msg, file, line); } -#else std::unique_ptr CreateGPUDataTransfer() override { return GetProviderInfo_CUDA().CreateGPUDataTransfer(); } -#endif std::string GetEnvironmentVar(const std::string& var_name) override { return Env::Default().GetEnvironmentVar(var_name); } @@ -1912,12 +1888,7 @@ static ProviderLibrary s_library_cann(LIBRARY_PREFIX ORT_TSTR("onnxruntime_provi false /* unload - On Linux if we unload the cann shared provider we crash */ #endif ); -static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_rocm") LIBRARY_EXTENSION -#ifndef _WIN32 - , - false /* unload - On Linux if we unload the rocm shared provider we crash */ -#endif -); + static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION); static ProviderLibrary s_library_vitisai(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_vitisai") LIBRARY_EXTENSION #ifndef _WIN32 @@ -1963,7 +1934,6 @@ void UnloadSharedProviders() { s_library_cuda.Unload(); s_library_cuda_test.Unload(); s_library_cann.Unload(); - s_library_rocm.Unload(); s_library_shared.Unload(); s_library_migraphx.Unload(); s_library_qnn.Unload(); @@ -1978,13 +1948,6 @@ std::unique_ptr CreateCUDAPinnedAllocator(const char* name) { return nullptr; } -std::unique_ptr CreateROCMPinnedAllocator(const char* name) { - if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) - return info->CreateROCMPinnedAllocator(name); - - return nullptr; -} - // Adapter to convert the legacy OrtCUDAProviderOptions to the latest OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const OrtCUDAProviderOptions* legacy_cuda_options) { OrtCUDAProviderOptionsV2 cuda_options_converted{}; @@ -2029,10 +1992,6 @@ std::shared_ptr CudaProviderFactoryCreator::Create( return nullptr; } -std::shared_ptr RocmProviderFactoryCreator::Create(const OrtROCMProviderOptions* provider_options) { - return s_library_rocm.Get().CreateExecutionProviderFactory(provider_options); -} - std::shared_ptr CannProviderFactoryCreator::Create(const OrtCANNProviderOptions* provider_options) { return s_library_cann.Get().CreateExecutionProviderFactory(provider_options); @@ -2326,20 +2285,6 @@ ProviderInfo_Dnnl& GetProviderInfo_Dnnl() { ORT_THROW("oneDNN Provider not available, can't get interface for it"); } -ProviderInfo_ROCM* TryGetProviderInfo_ROCM() try { - return reinterpret_cast(s_library_rocm.Get().GetInfo()); -} catch (const std::exception& exception) { - LOGS_DEFAULT(ERROR) << exception.what(); - return nullptr; -} - -ProviderInfo_ROCM& GetProviderInfo_ROCM() { - if (auto* info = TryGetProviderInfo_ROCM()) - return *info; - - ORT_THROW("ROCM Provider not available, can't get interface for it"); -} - ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX() try { return reinterpret_cast(s_library_migraphx.Get().GetInfo()); } catch (const std::exception& exception) { @@ -2362,16 +2307,12 @@ void CopyGpuToCpu( const OrtMemoryInfo& src_location) { if (auto* info = onnxruntime::TryGetProviderInfo_CUDA()) return info->CopyGpuToCpu(dst_ptr, src_ptr, size, dst_location, 
src_location); - if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) - return info->CopyGpuToCpu(dst_ptr, src_ptr, size, dst_location, src_location); ORT_THROW("GPU-to-CPU copy is not implemented."); } void cudaMemcpy_HostToDevice(void* dst, const void* src, size_t count) { if (auto* info = onnxruntime::TryGetProviderInfo_CUDA()) return info->cudaMemcpy_HostToDevice(dst, src, count); - if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) - return info->rocmMemcpy_HostToDevice(dst, src, count); ORT_THROW("cudaMemcpy_HostToDevice is not implemented."); } @@ -2395,14 +2336,6 @@ INcclService& INcclService::GetInstance() { } // namespace cuda #endif -#if defined(USE_ROCM) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) -namespace rocm { -INcclService& INcclService::GetInstance() { - return GetProviderInfo_ROCM().GetINcclService(); -} -} // namespace rocm -#endif - void UpdateProviderInfo_Tensorrt(OrtTensorRTProviderOptionsV2* provider_options, const ProviderOptions& options) { s_library_tensorrt.Get().UpdateProviderOptions(reinterpret_cast(provider_options), options); } @@ -2567,12 +2500,7 @@ ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, [[maybe_unused]] _In_ int de return info->SetCurrentGpuDeviceId(device_id); #endif -#ifdef USE_ROCM - if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) - return info->SetCurrentGpuDeviceId(device_id); -#endif - - return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available."); + return CreateStatus(ORT_FAIL, "CUDA execution provider is either not enabled or not available."); API_IMPL_END } @@ -2584,12 +2512,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, [[maybe_unused]] _In_ int* d return info->GetCurrentGpuDeviceId(device_id); #endif -#ifdef USE_ROCM - if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) - return info->GetCurrentGpuDeviceId(device_id); -#endif - - return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available."); + return CreateStatus(ORT_FAIL, "CUDA execution provider is either not enabled or not available."); API_IMPL_END } @@ -2605,25 +2528,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA, _In_ Or API_IMPL_END } -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, int device_id) { - OrtROCMProviderOptions provider_options{}; - provider_options.device_id = device_id; - - return OrtApis::SessionOptionsAppendExecutionProvider_ROCM(options, &provider_options); -} - -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) { - API_IMPL_BEGIN - auto factory = onnxruntime::RocmProviderFactoryCreator::Create(rocm_options); - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Rocm: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - return nullptr; - API_IMPL_END -} - ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) { API_IMPL_BEGIN @@ -2735,8 +2639,7 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions, #if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) || \ defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || \ defined(USE_CANN) || \ - defined(USE_DNNL) || \ - defined(USE_ROCM) + 
defined(USE_DNNL) static std::string BuildOptionsString(const onnxruntime::ProviderOptions::iterator& begin, const onnxruntime::ProviderOptions::iterator& end) { std::ostringstream options; @@ -3136,86 +3039,6 @@ ORT_API(void, OrtApis::ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProvid #endif } -ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out) { - API_IMPL_BEGIN -#ifdef USE_ROCM - auto options = std::make_unique(); - options->device_id = 0; - options->miopen_conv_exhaustive_search = 0; - options->gpu_mem_limit = std::numeric_limits::max(); - options->arena_extend_strategy = 0; - options->do_copy_in_default_stream = 1; - options->has_user_compute_stream = 0; - options->user_compute_stream = nullptr; - options->default_memory_arena_cfg = nullptr; - options->enable_hip_graph = false; - options->tunable_op_enable = 0; - options->tunable_op_tuning_enable = 0; - options->tunable_op_max_tuning_duration_ms = 0; - - *out = options.release(); - return nullptr; -#else - ORT_UNUSED_PARAMETER(out); - return CreateStatus(ORT_FAIL, "ROCm execution provider is not enabled in this build."); -#endif - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::UpdateROCMProviderOptions, - _Inout_ OrtROCMProviderOptions* rocm_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - size_t num_keys) { - API_IMPL_BEGIN -#ifdef USE_ROCM - onnxruntime::ProviderOptions provider_options_map; - for (size_t i = 0; i != num_keys; ++i) { - if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || - provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty"); - } - - provider_options_map[provider_options_keys[i]] = provider_options_values[i]; - } - - onnxruntime::s_library_rocm.Get().UpdateProviderOptions(rocm_options, provider_options_map); - return nullptr; -#else - ORT_UNUSED_PARAMETER(rocm_options); - ORT_UNUSED_PARAMETER(provider_options_keys); - ORT_UNUSED_PARAMETER(provider_options_values); - ORT_UNUSED_PARAMETER(num_keys); - return CreateStatus(ORT_FAIL, "ROCm execution provider is not enabled in this build."); -#endif - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, - _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr) { - API_IMPL_BEGIN -#ifdef USE_ROCM - onnxruntime::ProviderOptions options = onnxruntime::s_library_rocm.Get().GetProviderOptions(rocm_options); - std::string options_str = BuildOptionsString(options.begin(), options.end()); - *ptr = onnxruntime::StrDup(options_str, allocator); - return nullptr; -#else - ORT_UNUSED_PARAMETER(rocm_options); - ORT_UNUSED_PARAMETER(allocator); - ORT_UNUSED_PARAMETER(ptr); - return CreateStatus(ORT_FAIL, "ROCm execution provider is not enabled in this build."); -#endif - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { -#ifdef USE_ROCM - std::unique_ptr p(ptr); -#else - ORT_UNUSED_PARAMETER(ptr); -#endif -} - ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { diff --git a/onnxruntime/core/session/provider_policy_context.cc 
b/onnxruntime/core/session/provider_policy_context.cc index edd937c870260..a5258a4811bf7 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -214,6 +214,64 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const "No execution providers selected. Please check the device policy and available devices."); } + // Log telemetry for auto EP selection + { + std::vector requested_ep_ids; + requested_ep_ids.reserve(devices_selected.size()); + + for (const auto* device : devices_selected) { + if (device != nullptr) { + requested_ep_ids.push_back(device->ep_name); + } + } + + // Extract available execution provider IDs + std::vector available_ep_ids; + available_ep_ids.reserve(execution_devices.size()); + for (const auto* device : execution_devices) { + available_ep_ids.push_back(device->ep_name); + } + + std::string policy_type; + if (options.value.ep_selection_policy.delegate) { + policy_type = "custom_delegate"; + } else { + switch (options.value.ep_selection_policy.policy) { + case OrtExecutionProviderDevicePolicy_DEFAULT: + policy_type = "DEFAULT"; + break; + case OrtExecutionProviderDevicePolicy_PREFER_CPU: + policy_type = "PREFER_CPU"; + break; + case OrtExecutionProviderDevicePolicy_PREFER_NPU: + policy_type = "PREFER_NPU"; + break; + case OrtExecutionProviderDevicePolicy_PREFER_GPU: + policy_type = "PREFER_GPU"; + break; + case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: + policy_type = "MAX_PERFORMANCE"; + break; + case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: + policy_type = "MAX_EFFICIENCY"; + break; + case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: + policy_type = "MIN_OVERALL_POWER"; + break; + default: + policy_type = "UNKNOWN"; + break; + } + } + + const Env& os_env = Env::Default(); + os_env.GetTelemetryProvider().LogAutoEpSelection( + sess.GetCurrentSessionId(), + policy_type, + requested_ep_ids, + available_ep_ids); + } + // Configure the session options for the devices. This updates the SessionOptions in the InferenceSession with any // EP options that have not been overridden by the user. ORT_RETURN_IF_ERROR(AddEpDefaultOptionsToSession(sess, devices_selected)); @@ -264,7 +322,11 @@ void ProviderPolicyContext::FoldSelectedDevices(std::vector }); if (iter != devices_selected.end()) { - info.devices.push_back((*iter)->device); + info.devices.push_back(*iter); + // hardware device and metadata come from the OrtEpDevice but we need a collection of just the pointers + // to pass through to the CreateEp call. other info in the OrtEpDevice is used on the ORT side like the + // allocator and data transfer setup. 
+ info.hardware_devices.push_back((*iter)->device); info.ep_metadata.push_back(&(*iter)->ep_metadata); devices_selected.erase(iter); } else { @@ -284,15 +346,16 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or if (internal_factory) { // this is a factory we created and registered internally for internal and provider bridge EPs ORT_RETURN_IF_ERROR(ToStatusAndRelease( - internal_factory->CreateIExecutionProvider(info.devices.data(), info.ep_metadata.data(), - info.devices.size(), &options, &logger, + internal_factory->CreateIExecutionProvider(info.hardware_devices.data(), info.ep_metadata.data(), + info.hardware_devices.size(), &options, &logger, &ep))); } else { OrtEp* api_ep = nullptr; - ORT_RETURN_IF_ERROR(ToStatusAndRelease(info.ep_factory->CreateEp(info.ep_factory, info.devices.data(), - info.ep_metadata.data(), info.devices.size(), - &options, &logger, &api_ep))); - ep = std::make_unique(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options); + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + info.ep_factory->CreateEp(info.ep_factory, info.hardware_devices.data(), info.ep_metadata.data(), + info.hardware_devices.size(), &options, &logger, &api_ep))); + ep = std::make_unique(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options, + *info.ep_factory, info.devices); } return Status::OK(); diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h index 185f9523312ba..295ac21ca4aa5 100644 --- a/onnxruntime/core/session/provider_policy_context.h +++ b/onnxruntime/core/session/provider_policy_context.h @@ -13,7 +13,8 @@ namespace onnxruntime { struct SelectionInfo { OrtEpFactory* ep_factory; - std::vector devices; + std::vector devices; + std::vector hardware_devices; std::vector ep_metadata; }; diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 1f74ee3b3f2ee..18a463ef69943 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -322,11 +322,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, API_IMPL_END } -#if defined(__APPLE__) || defined(ORT_MINIMAL_BUILD) static OrtStatus* CreateNotEnabledStatus(const std::string& ep) { return OrtApis::CreateStatus(ORT_FAIL, (ep + " execution provider is not enabled in this build. 
").c_str()); } -#endif /** * Stubs for the publicly exported static registration functions for EPs that are referenced in the C# bindings @@ -445,13 +443,6 @@ ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, _In_ int device_id) { return CreateNotEnabledStatus("CUDA"); } -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, - _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* provider_options) { - ORT_UNUSED_PARAMETER(options); - ORT_UNUSED_PARAMETER(provider_options); - return CreateNotEnabledStatus("ROCM"); -} - ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) { ORT_UNUSED_PARAMETER(options); @@ -617,6 +608,23 @@ ORT_API(void, OrtApis::ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProvid ORT_UNUSED_PARAMETER(ptr); } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateNotEnabledStatus("VitisAI"); +} +#endif +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, + _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* provider_options) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(provider_options); + return CreateNotEnabledStatus("ROCM"); +} + ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out) { ORT_UNUSED_PARAMETER(out); return CreateNotEnabledStatus("ROCM"); @@ -647,14 +655,3 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { ORT_UNUSED_PARAMETER(ptr); } - -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, - _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { - ORT_UNUSED_PARAMETER(options); - ORT_UNUSED_PARAMETER(provider_options_keys); - ORT_UNUSED_PARAMETER(provider_options_values); - ORT_UNUSED_PARAMETER(num_keys); - return CreateNotEnabledStatus("VitisAI"); -} -#endif diff --git a/onnxruntime/python/tools/kernel_explorer/README.md b/onnxruntime/python/tools/kernel_explorer/README.md deleted file mode 100644 index a3adae45df508..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/README.md +++ /dev/null @@ -1,51 +0,0 @@ -# Kernel Explorer - -Kernel Explorer hooks up GPU kernel code with a Python frontend to help develop, test, profile, and auto-tune GPU kernels. The initial scope is for BERT-like models with ROCM EP. 
- -## Build - -```bash -#!/bin/bash - -set -ex - -build_dir="build" -config="Release" - -rocm_home="/opt/rocm" - -./build.sh --update \ - --build_dir ${build_dir} \ - --config ${config} \ - --cmake_extra_defines \ - CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ \ - onnxruntime_BUILD_KERNEL_EXPLORER=ON \ - --skip_submodule_sync --skip_tests \ - --use_rocm --rocm_home=${rocm_home} --nccl_home=${rocm_home} \ - --build_wheel - -cmake --build ${build_dir}/${config} --target kernel_explorer --parallel -``` - -## Run - -Taking `vector_add_test.py` and build configuration with `build_dir="build"` and `config="Release"` in the previous section as an example. - -Set up the native library search path with the following environment variable: -```bash -export KERNEL_EXPLORER_BUILD_DIR=`realpath build/Release` -``` - -To test a kernel implementation, `pip install pytest` and then - -```bash -pytest onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py -``` - -To run the microbenchmarks: - -```bash -python onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py -``` - -Currently, kernel explorer mainly targets kernel developers, not the onnxruntime package end users, so it is not installed via `setup.py`. diff --git a/onnxruntime/python/tools/kernel_explorer/device_array.h b/onnxruntime/python/tools/kernel_explorer/device_array.h deleted file mode 100644 index c3e502ece5a9f..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/device_array.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#ifdef USE_CUDA -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/tunable/util.h" - -#define CALL_THROW CUDA_CALL_THROW -#define MALLOC cudaMalloc -#define FREE cudaFree -#define MEMCPY cudaMemcpy -#define MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice -#define MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost -#elif USE_ROCM -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/util.h" - -#define CALL_THROW HIP_CALL_THROW -#define MALLOC hipMalloc -#define FREE hipFree -#define MEMCPY hipMemcpy -#define MEMCPY_HOST_TO_DEVICE hipMemcpyHostToDevice -#define MEMCPY_DEVICE_TO_HOST hipMemcpyDeviceToHost -#endif - -namespace py = pybind11; - -namespace onnxruntime { - -class DeviceArray { - public: - DeviceArray(size_t ptr, ssize_t size, ssize_t itemsize) - : host_{reinterpret_cast(ptr)}, size_{size}, itemsize_{itemsize} { - void* dev_ptr; - CALL_THROW(MALLOC(&dev_ptr, size_ * itemsize_)); - device_.reset(dev_ptr, [](void* dev_ptr) { CALL_THROW(FREE(dev_ptr)); }); - CALL_THROW(MEMCPY(device_.get(), host_, size_ * itemsize_, MEMCPY_HOST_TO_DEVICE)); - } - explicit DeviceArray(py::array x) : DeviceArray(x.request()) {} - DeviceArray(const DeviceArray&) = default; - DeviceArray& operator=(const DeviceArray&) = default; - - void UpdateHostNumpyArray() { - CALL_THROW(MEMCPY(host_, device_.get(), size_ * itemsize_, MEMCPY_DEVICE_TO_HOST)); - } - - void UpdateDeviceArray() { - CALL_THROW(MEMCPY(device_.get(), host_, size_ * itemsize_, MEMCPY_HOST_TO_DEVICE)); - } - - void* ptr() const { - return device_.get(); - } - - private: - explicit DeviceArray(py::buffer_info buf) : DeviceArray(reinterpret_cast(buf.ptr), buf.size, buf.itemsize) {} - - std::shared_ptr device_; - void* host_; - py::ssize_t size_; - py::ssize_t itemsize_; -}; - -} // namespace onnxruntime - -#undef CALL_THROW -#undef MALLOC -#undef FREE -#undef MEMCPY 
-#undef MEMCPY_HOST_TO_DEVICE -#undef MEMCPY_DEVICE_TO_HOST diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc deleted file mode 100644 index 5eb05edefdcfc..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include -#include - -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -namespace py = pybind11; - -namespace onnxruntime { - -static py::module::module_def _kernel_explorer_module_def; - -bool TuningInfo::collect_enabled_{false}; -std::vector TuningInfo::collected_tuning_results_ = {}; -std::optional TuningInfo::max_tuning_duration_ms_ = {}; - -py::module GetKernelExplorerModule() { - static pybind11::module_ m = []() { - auto tmp = pybind11::module_::create_extension_module( - "_kernel_explorer", "", &_kernel_explorer_module_def); - tmp.dec_ref(); - return tmp; - }(); - return m; -} - -PYBIND11_PLUGIN_IMPL(_kernel_explorer) { - PYBIND11_CHECK_PYTHON_VERSION; - PYBIND11_ENSURE_INTERNALS_READY; - return GetKernelExplorerModule().ptr(); -} - -KE_REGISTER(m) { - py::class_(m, "DeviceArray") - .def(py::init()) - .def(py::init()) - .def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray) - .def("UpdateDeviceArray", &DeviceArray::UpdateDeviceArray); - - m.def("enable_collect_tuning_results", TuningInfo::EnableCollect, pybind11::arg("enable") = true); - - m.def("max_tuning_duration_ms", TuningInfo::SetMaxTuningDurationMs); - - m.def("get_collected_tuning_results", []() { - py::list ret; - for (const auto& trs : TuningInfo::GetCollectedTuningResults()) { - py::dict py_trs; - py_trs["ep"] = trs.ep; - py_trs["results"] = trs.results; - py_trs["validators"] = trs.validators; - ret.append(std::move(py_trs)); - } - return ret; - }); - - // clang-format ill-format the following code below version 18 - // clang-format off - m.def("is_composable_kernel_available", []() { -#ifdef USE_COMPOSABLE_KERNEL - return true; -#else - return false; -#endif - }); - - m.def("is_hipblaslt_available", []() { -#ifdef USE_HIPBLASLT - return true; -#else - return false; -#endif - }); - - m.def("is_float8_available", []() { -#ifndef DISABLE_FLOAT8_TYPES - return true; -#else - return false; -#endif - }); - // clang-format on -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h b/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h deleted file mode 100644 index 1c7232e6a5cd0..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include - -#include "core/providers/shared_library/provider_api.h" -#ifdef USE_CUDA -#include -#include "core/providers/cuda/cuda_execution_provider.h" -#include "core/providers/cuda/tunable/cuda_tunable.h" -#include "core/providers/cuda/tunable/util.h" -#elif USE_ROCM -#include -#include "core/providers/rocm/rocm_execution_provider.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "core/providers/rocm/tunable/util.h" -#endif - -#ifdef USE_CUDA -using onnxruntime::cuda::tunable::Timer; -using ExecutionProvider = onnxruntime::CUDAExecutionProvider; -using ExecutionProviderInfo = onnxruntime::CUDAExecutionProviderInfo; -using StreamT = cudaStream_t; -using TuningContextT = onnxruntime::cuda::tunable::CudaTuningContext; -#elif USE_ROCM -using onnxruntime::rocm::tunable::Timer; -using ExecutionProvider = onnxruntime::ROCMExecutionProvider; -using ExecutionProviderInfo = onnxruntime::ROCMExecutionProviderInfo; -using StreamT = hipStream_t; -using TuningContextT = onnxruntime::rocm::tunable::RocmTuningContext; -#else -#error "kernel explorer only supports CUDA or ROCM" -#endif - -namespace onnxruntime { - -struct TuningInfo { - static void EnableCollect(bool b) { - collect_enabled_ = b; - } - - static std::vector GetCollectedTuningResults() { - return collected_tuning_results_; - } - - static void SetMaxTuningDurationMs(int milliseconds) { - max_tuning_duration_ms_ = milliseconds; - } - - static bool collect_enabled_; - static std::vector collected_tuning_results_; - static std::optional max_tuning_duration_ms_; -}; - -/// Wrapping around Op and TunableOp -class IKernelExplorer { - public: - virtual void Run() = 0; - - void SetRepeats(int n) { - repeats_ = n; - } - - float Profile() { - // warm up - for (int i = 0; i < 5; i++) { - Run(); - } - Timer timer{static_cast(Stream()->GetHandle())}; - timer.Start(); - for (int i = 0; i < repeats_; i++) { - Run(); - } - timer.End(); - return timer.Duration() / repeats_; - } - - virtual ~IKernelExplorer() { - if (TuningInfo::collect_enabled_) { - TuningInfo::collected_tuning_results_.emplace_back(this->ep_->GetTuningContext()->GetTuningResults()); - } - } - - protected: - ExecutionProvider* GetEp() { - std::call_once(ep_create_once_, [this]() { - ExecutionProviderInfo info{}; - this->ep_ = std::make_unique(info); - auto allocators = this->ep_->CreatePreferredAllocators(); - for (auto& alloc : allocators) { - this->allocators_.insert({alloc->Info().device, alloc}); - } - auto tuning_ctx = this->ep_->GetTuningContext(); - if (nullptr != tuning_ctx) { - tuning_ctx->RegisterAllocatorsView(&this->allocators_); - for (const auto& tr : TuningInfo::collected_tuning_results_) { - auto status = tuning_ctx->LoadTuningResults(tr); - if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << status; - } - } - if (TuningInfo::max_tuning_duration_ms_.has_value()) { - tuning_ctx->SetMaxTuningDurationMs(*TuningInfo::max_tuning_duration_ms_); - } - } - stream_ = std::make_unique(nullptr, this->ep_->GetOrtDeviceByMemType(OrtMemTypeDefault)); - }); - return ep_.get(); - } - - TuningContextT* TuningContext() { - return static_cast(GetEp()->GetTuningContext()); - } - - onnxruntime::Stream* Stream() { return stream_.get(); } - - private: - std::once_flag ep_create_once_; - std::unique_ptr ep_{}; - std::map allocators_; - OrtDevice dev_; - std::unique_ptr stream_; - int repeats_{100}; -}; - -class WithMaxTuningDurationMs { - public: - WithMaxTuningDurationMs(TuningContextT* ctx, int ms) : ctx_(ctx) { - original_tuning_duration_ = 
ctx_->GetMaxTuningDurationMs(); - ctx_->SetMaxTuningDurationMs(ms); - } - - ~WithMaxTuningDurationMs() { - ctx_->SetMaxTuningDurationMs(original_tuning_duration_); - } - - private: - TuningContextT* ctx_; - int original_tuning_duration_; -}; - -pybind11::module GetKernelExplorerModule(); - -class KernelExplorerInit { - public: - explicit KernelExplorerInit(void (*init_func)(pybind11::module module)) { - init_func(GetKernelExplorerModule()); - } -}; - -#define KE_REGISTER_IMPL(unique_id, module_name) \ - static void KeInitFunc##unique_id(pybind11::module module_name); \ - static const KernelExplorerInit kKeInitializer##unique_id{KeInitFunc##unique_id}; \ - void KeInitFunc##unique_id(pybind11::module module_name) - -#define KE_REGISTER_(unique_id, module_name) KE_REGISTER_IMPL(unique_id, module_name) -#define KE_REGISTER(module_name) KE_REGISTER_(__COUNTER__, module_name) - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi b/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi deleted file mode 100644 index 4682f7135d7a3..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi +++ /dev/null @@ -1,20 +0,0 @@ -class DeviceArray: - def __init__(self, ndarray) -> None: ... - def UpdateHostNumpyArray(self) -> None: ... # noqa: N802 - def UpdateDeviceArray(self) -> None: ... # noqa: N802 - -class blas_op: # noqa: N801 - T: int - N: int - -class qkv_format: # noqa: N801 - Q_K_V_BNSH: int - Q_K_V_BSNH: int - QKV_BSN3H: int - Q_KV_BSNH_BSN2H: int - -def is_composable_kernel_available(*args, **kwargs): ... -def is_hipblaslt_available(*args, **kwargs): ... - -def enable_collect_tuning_results(*args, **kwargs): ... -def get_collected_tuning_results(*args, **kwargs): ... diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py deleted file mode 100644 index 01d51099ca577..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py +++ /dev/null @@ -1,215 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
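The removed test and profiling scripts that follow all lean on the same workflow exposed by the interface above: wrap NumPy buffers in `ke.DeviceArray` (which copies host to device in the constructor), build one of the registered op wrappers, walk its implementations with `ListOps()`/`SelectOp()`, then call `Run()` for correctness checks or `Profile()` for timing. A minimal sketch of that round trip; the wrapper name is hypothetical, everything else comes from the code in this diff:

```python
import kernel_explorer as ke
import numpy as np

x = np.random.rand(16, 16).astype("float16")
y = np.zeros_like(x)

x_d = ke.DeviceArray(x)  # host -> device copy happens on construction
y_d = ke.DeviceArray(y)

op = ke.SomeTunableOp_half(x_d, y_d, x.size)  # hypothetical registered wrapper
op.SetRepeats(100)  # optional; 100 is the default in IKernelExplorer
for impl in op.ListOps():
    if not op.SelectOp(impl):  # implementation not usable on this device/build
        continue
    op.Run()
    y_d.UpdateHostNumpyArray()  # device -> host copy back into the original ndarray
    avg_time = op.Profile()     # average kernel time over the configured repeats
```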
-# -------------------------------------------------------------------------- - -import os -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix - -max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", "64")) - - -def dtype_to_suffix(dtype): - return { - "float32": "float", - "float16": "half", - }[dtype] - - -@ke.dispatchable -def _test_batched_gemm( - func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int, alpha=1.0, beta=0.0 -): - assert dtype in ["float32", "float16"] - - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - as_ = [(np.random.rand(*a_shape) + 0.5).astype(dtype).astype("float64") for i in range(batch)] - bs = [(np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64") for i in range(batch)] - intermediate_cs = [matmul(as_[i], bs[i], transa, transb) for i in range(batch)] - if alpha == 1.0 and beta == 0.0: # fast path - ref_cs = intermediate_cs - else: - ref_cs = [alpha * cs + beta * np.ones_like(cs) for cs in intermediate_cs] - - bounds = [get_gemm_bound(dtype, as_[i], bs[i], ref_cs[i], transa, transb, a_b_positive=True) for i in range(batch)] - - as_ = [a.astype(dtype) for a in as_] - bs = [b.astype(dtype) for b in bs] - - my_cs = [np.ones((m, n), dtype=dtype) for i in range(batch)] - dev_as = [ke.DeviceArray(a) for a in as_] - dev_bs = [ke.DeviceArray(b) for b in bs] - dev_cs = [ke.DeviceArray(my_c) for my_c in my_cs] - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - ldc = n - my_gemm = func(opa, opb, m, n, k, alpha, dev_as, lda, dev_bs, ldb, beta, dev_cs, ldc, batch) - - failures = {} - print( - f"dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} batch={batch:<3} max bound: {max(bounds)}" - ) - - for impl in my_gemm.ListOps(): - if not my_gemm.SelectOp(impl): - continue - - # Restore C Arrays - for my_c in my_cs: - my_c.fill(1.0) - for dev_c in dev_cs: - dev_c.UpdateDeviceArray() - my_gemm.Run() - for dev_c in dev_cs: - dev_c.UpdateHostNumpyArray() - - for i in range(batch): - try: - np.testing.assert_allclose(my_cs[i], ref_cs[i], rtol=bounds[i]) - except Exception as err: - header = "*" * 30 + impl + "*" * 30 - print(header, bounds[i]) - print(err) - print("*" * len(header)) - failures[impl] = str(err) - - if failures: - raise Exception(failures) - - -dtypes = ["float32", "float16"] -all_transabs = list(product([True, False], repeat=2)) - - -@pytest.mark.parametrize("batch", [1, max_batch_size]) -@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch): - wrapper_name = "RocBlasBatchedGemm_" + dtype_to_suffix(dtype) - _test_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) - - -# Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests -@pytest.mark.parametrize("batch", [1, max_batch_size]) -@pytest.mark.parametrize("m, n, k", get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def 
test_gemm_tunable_bert_cases(dtype, transa, transb, m, n, k, batch): - wrapper_name = f"BatchedGemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) - - -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_rocblas_gemm_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = "RocBlasBatchedGemm_" + dtype_to_suffix(dtype) - _test_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, 512, 512, 768, 8, alpha=alpha, beta=beta) - - -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_tunable_gemm_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = f"BatchedGemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, 768, 768, 512, 4, alpha=alpha, beta=beta) - - -@dataclass -class BatchedGemmMetric(ke.ComputeMetric): - transa: bool - transb: bool - m: int - n: int - k: int - batch: int - - def report(self): - common = ( - f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} " - f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} {self.name}" - ) - if self.duration <= 0: - return "not supported " + common - - return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common - - -@ke.dispatchable(pattern_arg=0) -def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int): - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - as_ = [(np.random.rand(*a_shape) + 0.5).astype(dtype).astype("float64") for i in range(batch)] - bs = [(np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64") for i in range(batch)] - - my_cs = [np.zeros((m, n), dtype=dtype) for i in range(batch)] - dev_as = [ke.DeviceArray(a) for a in as_] - dev_bs = [ke.DeviceArray(b) for b in bs] - dev_cs = [ke.DeviceArray(my_c) for my_c in my_cs] - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - ldc = n - alpha = 1.0 - beta = 0.0 - my_gemm = f(opa, opb, m, n, k, alpha, dev_as, lda, dev_bs, ldb, beta, dev_cs, ldc, batch) - for impl in my_gemm.ListOps(): - duration_ms = -1 - if my_gemm.SelectOp(impl): - duration_ms = my_gemm.Profile() - flops = batch * m * k * n * 2 - ke.report(BatchedGemmMetric(impl, dtype, duration_ms, flops, transa, transb, m, n, k, batch)) - - -@ke.dispatchable -def profile_with_args(dtype, transa, transb, m, n, k, batch): - dtype_suffix = "_" + dtype_to_suffix(dtype) - transab_suffix = "_" + transab_to_suffix((transa, transb)) - fn_rocblas = getattr(ke, "RocBlasBatchedGemm" + dtype_suffix) - fn_tunable = getattr(ke, "BatchedGemmTunable" + dtype_suffix + transab_suffix) - with ke.benchmark(): - profile_gemm_func(fn_rocblas, dtype, transa, transb, m, n, k, batch) - profile_gemm_func(fn_tunable, dtype, transa, transb, m, n, k, batch) - print() - - -def profile(): - for dtype in dtypes: - for m, n, k in get_gemm_bert_sizes(full=False): - for batch in [1, 32, 64]: - profile_with_args(dtype, False, False, m, n, k, batch) - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("dtype", 
choices=dtypes) - group.add_argument("transa", choices="NT") - group.add_argument("transb", choices="NT") - group.add_argument("m", type=int) - group.add_argument("n", type=int) - group.add_argument("k", type=int) - group.add_argument("batch", type=int) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu deleted file mode 100644 index 3504ce1bebe8c..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// This file serve as a simple example for adding a tunable op to onnxruntime. - -#include -#include -#include - -#include - -#include "core/providers/cuda/tunable/cuda_tunable.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -// Extend the OpParams so that all specializations have the same parameter passing interface -template -struct DequantizeBnb4Params : cuda::tunable::OpParams { - std::string Signature() const override { return std::to_string(n_); } - - int quant_type_; - T* output_; - const uint8_t* quant_; - const T* absmax_; - T* quant_map_buffer_; - int n_; - int k_; -}; - -template -class DequantizeBnb4 : public IKernelExplorer { - public: - DequantizeBnb4( - int quant_type, - DeviceArray& output, - DeviceArray& quant, - DeviceArray& absmax, - DeviceArray& quant_map_buffer, - int n, int k) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.quant_type_ = quant_type; - params_.output_ = static_cast(output.ptr()); - params_.quant_ = static_cast(quant.ptr()); - params_.absmax_ = static_cast(absmax.ptr()); - params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); - params_.n_ = n; - params_.k_ = k; - } - - void Run() override { - ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( - params_.quant_type_, - params_.quant_map_buffer_, - params_.StreamHandle())); - ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4( - params_.quant_map_buffer_, - params_.output_, - params_.quant_, - params_.absmax_, - 64, - params_.n_ * params_.k_, - params_.StreamHandle())); - } - - private: - // A VectorAddOp is a callable that can process const VectorAddParams* - using ParamsT = DequantizeBnb4Params; - ParamsT params_{}; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); - -KE_REGISTER(m) { - REGISTER_OP(DequantizeBnb4, half); - REGISTER_OP(DequantizeBnb4, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu deleted file mode 100644 index e6dee290a6fc4..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
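A note on naming: the `REGISTER_OP(DequantizeBnb4, half)` expansion above publishes the wrapper to Python as `DequantizeBnb4_half` (class name plus type suffix), which is why the profiling scripts later in this diff can discover kernels simply by filtering `dir(ke)`. For illustration, under that naming scheme:

```python
import kernel_explorer as ke

# Finds e.g. "DequantizeBnb4_half" when the extension was built with these kernels;
# the same filtering idiom appears in dequantize_blockwise_bnb4.py below.
bnb4_half_ops = [name for name in dir(ke) if "DequantizeBnb4_half" in name]
print(bnb4_half_ops)
```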
- -// This file serve as a simple example for adding a tunable op to onnxruntime. - -#include -#include -#include - -#include - -#include "core/providers/cuda/tunable/cuda_tunable.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -// Extend the OpParams so that all specializations have the same parameter passing interface -template -struct DequantizeInt4Params : cuda::tunable::OpParams { - std::string Signature() const override { return std::to_string(n_); } - - T* output_; - const uint8_t* quant_; - const T* scales_; - const uint8_t* zero_points_; - int n_; - int k_; -}; - -template -class DequantizeInt4 : public IKernelExplorer { - public: - DequantizeInt4(DeviceArray& output, DeviceArray& quant, DeviceArray& scales, int n, int k) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.output_ = static_cast(output.ptr()); - params_.quant_ = static_cast(quant.ptr()); - params_.scales_ = static_cast(scales.ptr()); - params_.zero_points_ = nullptr; - params_.n_ = n; - params_.k_ = k; - } - - void Run() override { - ORT_THROW_IF_ERROR(contrib::cuda::Dequantize4Bits( - params_.output_, - params_.quant_, - params_.scales_, - params_.zero_points_, - nullptr, /*reorder_idx*/ - params_.k_, - params_.n_, - 32, - params_.StreamHandle())); - } - - private: - // A VectorAddOp is a callable that can process const VectorAddParams* - using ParamsT = DequantizeInt4Params; - ParamsT params_{}; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); - -KE_REGISTER(m) { - REGISTER_OP(DequantizeInt4, half); - REGISTER_OP(DequantizeInt4, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu deleted file mode 100644 index 8b05b96ec38a9..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// This file serve as a simple example for adding a tunable op to onnxruntime. 
- -#include -#include - -#include - -#include - -#include "core/providers/cuda/tunable/cuda_tunable.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/providers/cuda/cuda_stream_handle.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" -#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -// Extend the OpParams so that all specializations have the same parameter passing interface -template -struct GemmBenchmarkParams : cuda::tunable::OpParams { - std::string Signature() const override { return std::to_string(n_); } - - T* output_; - const T* a_; - const T* b_; - int m_; - int n_; - int k_; - cublasHandle_t cublas_handle; -}; - -template -class GemmBenchmark : public IKernelExplorer { - public: - GemmBenchmark(DeviceArray& output, DeviceArray& a, DeviceArray& b, int m, int n, int k) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.output_ = static_cast(output.ptr()); - params_.a_ = static_cast(a.ptr()); - params_.b_ = static_cast(b.ptr()); - params_.m_ = m; - params_.n_ = n; - params_.k_ = k; - - CUBLAS_CALL_THROW(cublasCreate(&(params_.cublas_handle))); - CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, 0)); - } - - void Run() override { - typedef typename ToCudaType::MappedType CudaT; - CudaT one = ToCudaType::FromFloat(1.0f); - CudaT zero = ToCudaType::FromFloat(0.0f); - - // TF32 is enable by default. To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 - constexpr bool use_tf32 = true; - CUBLAS_CALL_THROW(cublasGemmHelper( - params_.cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - params_.n_, params_.m_, params_.k_, - &one, - reinterpret_cast(params_.b_), - params_.n_, - reinterpret_cast(params_.a_), - params_.k_, - &zero, - params_.output_, - params_.n_, - device_prop_, - use_tf32)); - } - - private: - // A VectorAddOp is a callable that can process const VectorAddParams* - using ParamsT = GemmBenchmarkParams; - ParamsT params_{}; - cudaDeviceProp device_prop_; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); - -KE_REGISTER(m) { - REGISTER_OP(GemmBenchmark, half); - REGISTER_OP(GemmBenchmark, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu deleted file mode 100644 index bfff2c7cb0721..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include -#include - -#include - -#include "core/providers/cuda/tunable/cuda_tunable.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -// Extend the OpParams so that all specializations have the same parameter passing interface -template -struct MatrixFloatInt4Params : cuda::tunable::OpParams { - std::string Signature() const override { return std::to_string(n_); } - - T* output_; - const T* a_; - const uint8_t* b_; - const T* scales_; - const uint8_t* zero_points_; - int m_; - int n_; - int k_; -}; - -template -class MatrixFloatInt4 : public IKernelExplorer { - public: - MatrixFloatInt4(DeviceArray& output, - DeviceArray& a, - DeviceArray& b, - DeviceArray& scales, - int m, int n, int k) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.output_ = static_cast(output.ptr()); - params_.a_ = static_cast(a.ptr()); - params_.b_ = static_cast(b.ptr()); - params_.scales_ = static_cast(scales.ptr()); - params_.zero_points_ = nullptr; - params_.m_ = m; - params_.n_ = n; - params_.k_ = k; - - CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, 0)); - } - - MatrixFloatInt4(DeviceArray& output, - DeviceArray& a, - DeviceArray& b, - DeviceArray& scales, - DeviceArray& zeropoints, - int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) { - params_.zero_points_ = static_cast(zeropoints.ptr()); - } - - void Run() override { - contrib::cuda::TryMatMul4Bits( - params_.output_, - params_.a_, - params_.b_, - params_.scales_, - params_.zero_points_, - params_.m_, - params_.n_, - params_.k_, - 32, - static_cast(device_prop_.sharedMemPerBlock), - params_.StreamHandle()); - } - - private: - // A VectorAddOp is a callable that can process const VectorAddParams* - using ParamsT = MatrixFloatInt4Params; - ParamsT params_{}; - cudaDeviceProp device_prop_; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); - -KE_REGISTER(m) { - REGISTER_OP(MatrixFloatInt4, half); - REGISTER_OP(MatrixFloatInt4, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu deleted file mode 100644 index e4cd83565357a..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// This file serve as a simple example for adding a tunable op to onnxruntime. 
- -#include -#include -#include - -#include - -#include "core/providers/cuda/tunable/cuda_tunable.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" -#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" -#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -// Extend the OpParams so that all specializations have the same parameter passing interface -template -struct MatrixFloatBnb4Params : cuda::tunable::OpParams { - std::string Signature() const override { return std::to_string(n_); } - - int quant_type_; - T* output_; - const T* a_; - const uint8_t* b_; - const T* absmax_; - T* quant_map_buffer_; - int m_; - int n_; - int k_; -}; - -template -class MatrixFloatBnb4 : public IKernelExplorer { - public: - MatrixFloatBnb4(DeviceArray& output, - DeviceArray& a, - DeviceArray& b, - DeviceArray& absmax, - DeviceArray& quant_map_buffer, - int quant_type, int m, int n, int k) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.output_ = static_cast(output.ptr()); - params_.a_ = static_cast(a.ptr()); - params_.b_ = static_cast(b.ptr()); - params_.absmax_ = static_cast(absmax.ptr()); - params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); - params_.quant_type_ = quant_type; - params_.m_ = m; - params_.n_ = n; - params_.k_ = k; - } - - void Run() override { - ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( - params_.quant_type_, - params_.quant_map_buffer_, - params_.StreamHandle())); - contrib::cuda::TryMatMulBnb4( - params_.quant_map_buffer_, - params_.output_, - params_.a_, - params_.b_, - params_.absmax_, - params_.m_, - params_.n_, - params_.k_, - 64, - params_.StreamHandle()); - } - - private: - // A VectorAddOp is a callable that can process const VectorAddParams* - using ParamsT = MatrixFloatBnb4Params; - ParamsT params_{}; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); - -KE_REGISTER(m) { - REGISTER_OP(MatrixFloatBnb4, half); - REGISTER_OP(MatrixFloatBnb4, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py deleted file mode 100644 index 140151aadcc0f..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py +++ /dev/null @@ -1,92 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -import sys -from dataclasses import dataclass - -import kernel_explorer as ke -import numpy as np -from utils import dtype_to_bytes - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))), - "float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))), - } - return type_map[dtype] - - -quant_enums = {"FP4": 0, "NF4": 1} - - -dtypes = ["float16", "float32"] -quant_types = ["FP4", "NF4"] - - -@dataclass -class DequantizeBnb4Metric(ke.BandwidthMetric): - quant_type: str - n: int - k: int - - def report(self): - return ( - f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" - f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}" - ) - - -def profile_dequantize_int4_func(qt, n, k, dtype, func): - np.random.seed(0) - block_size = 64 - numel = n * k - output = np.random.rand(n, k).astype(dtype) - quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") - absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) - quant_map_buffer = np.zeros(16).astype(dtype) - - output_d = ke.DeviceArray(output) - quant_d = ke.DeviceArray(quant) - absmax_d = ke.DeviceArray(absmax) - quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) - f = getattr(ke, func) - my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k) - duration_ms = my_op.Profile() - total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype) - - ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k)) - - -def profile_with_args(qt, n, k, dtype, sort): - with ke.benchmark(sort): - for func in dtype_to_funcs(dtype): - profile_dequantize_int4_func(qt, n, k, dtype, func) - - -def profile(): - for qt in quant_types: - for dt in dtypes: - for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): - profile_with_args(qt, n, k, dt, True) - print() - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") - group.add_argument("n", type=int) - group.add_argument("k", type=int) - group.add_argument("quant_type", choices=quant_types) - group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", action="store_true") - - if len(sys.argv) == 1: - profile() - else: - args = parser.parse_args() - profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py deleted file mode 100644 index ba049fad773aa..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py +++ /dev/null @@ -1,76 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
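The `total_bytes` accounting in the bnb4 script above is worth spelling out: the 4-bit quantized weights contribute `numel / 2` bytes read, the dequantized output contributes `numel` elements written, and the per-block `absmax` scales add `numel / block_size` elements read, the latter two at the working dtype's width. A quick sanity check of that formula (the shape is just an example):

```python
def bnb4_dequant_bytes(n, k, dtype_bytes, block_size=64):
    # Mirrors: numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype)
    numel = n * k
    return numel / 2 + (numel + numel / block_size) * dtype_bytes

# 4096 x 4096 in float16 (2 bytes per element) moves roughly 42.5 MB.
print(bnb4_dequant_bytes(4096, 4096, 2) / 1e6)
```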
-# -------------------------------------------------------------------------- - -from dataclasses import dataclass - -import kernel_explorer as ke -import numpy as np -from utils import dtype_to_bytes - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: "DequantizeInt4_half" in x, dir(ke))), - "float32": list(filter(lambda x: "DequantizeInt4_float" in x, dir(ke))), - } - return type_map[dtype] - - -dtypes = ["float16", "float32"] - - -@dataclass -class DequantizeInt4Metric(ke.BandwidthMetric): - n: int - k: int - - def report(self): - return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} n={self.n} k={self.k} {self.name}" - - -@ke.dispatchable(pattern_arg=3) -def profile_dequantize_int4_func(n, k, dtype, func): - np.random.seed(0) - output = np.random.rand(n, k).astype(dtype) - quant = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8") - scales = np.random.rand(n, (k + 31) // 32).astype(dtype) - - output_d = ke.DeviceArray(output) - quant_d = ke.DeviceArray(quant) - scales_d = ke.DeviceArray(scales) - f = getattr(ke, func) - my_op = f(output_d, quant_d, scales_d, n, k) - duration_ms = my_op.Profile() - total_bytes = (n * k) / 2 + (n * k + n * k / 32) * dtype_to_bytes(dtype) - - ke.report(DequantizeInt4Metric(func, dtype, duration_ms, total_bytes, n, k)) - - -@ke.dispatchable -def profile_with_args(n, k, dtype): - with ke.benchmark(): - for func in dtype_to_funcs(dtype): - profile_dequantize_int4_func(n, k, dtype, func) - - -def profile(): - for dt in dtypes: - for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): - profile_with_args(n, k, dt) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("n", type=int) - group.add_argument("k", type=int) - group.add_argument("dtype", choices=dtypes) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.n, args.k, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/elementwise_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/elementwise_test.py deleted file mode 100644 index 425d8843814c3..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/elementwise_test.py +++ /dev/null @@ -1,142 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
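The elementwise tests below validate against `fast_gelu`, `gelu`, and `relu` reference implementations imported from `utils`, which is outside this diff. For FastGeLU the reference is presumably the usual tanh approximation of GELU with the bias folded in, along these lines (an assumption about `utils.fast_gelu`, not a quote of it):

```python
import numpy as np

def fast_gelu_reference(x, bias):
    # Tanh-approximation GELU applied after the bias add (assumed behaviour of utils.fast_gelu).
    y = x + bias
    return 0.5 * y * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (y + 0.044715 * y**3)))
```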
-# -------------------------------------------------------------------------- - -import re -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import dtype_to_bytes, fast_gelu, gelu, relu - - -def get_bert_sizes(): - batch_sizes = [1] - seq_lens = [384] - hidden_sizes = [1024] - return product(batch_sizes, seq_lens, hidden_sizes) - - -def dtype_to_funcs(fn_name, dtype): - type_map = { - "float16": list(filter(lambda x: re.match(f"{fn_name}.*_half.*", x), dir(ke))), - "float32": list(filter(lambda x: re.match(f"{fn_name}.*_float.*", x), dir(ke))), - "float64": list(filter(lambda x: re.match(f"{fn_name}.*_double.*", x), dir(ke))), - } - return type_map[dtype] - - -def fn_name_to_ref_impl(fn_name): - return { - "FastGeLU": fast_gelu, - "GeLU": gelu, - "ReLU": relu, - }[fn_name] - - -def run_elementwise(x_size, bias_size, fn_name, dtype, func): - np.random.seed(0) - x = np.random.rand(*x_size).astype(dtype) - bias = np.random.rand(bias_size).astype(dtype) - y = np.random.rand(*x_size).astype(dtype) - - x_d = ke.DeviceArray(x) - bias_d = ke.DeviceArray(bias) - y_d = ke.DeviceArray(y) - f = getattr(ke, func) - my_op = f(x_d, bias_d, y_d, x.size, bias.size) - if my_op.IsSupported(): - my_op.Run() - y_d.UpdateHostNumpyArray() - - ref_fn = fn_name_to_ref_impl(fn_name) - y_ref = ref_fn(x, bias) - np.testing.assert_allclose(y_ref, y, atol=1e-3, rtol=1e-04) - - -test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024), ((1, 2, 1027), 1027), ((1, 3, 3), 3)] -fn_names = ["FastGeLU", "GeLU", "ReLU"] -dtypes = ["float16", "float32"] - - -@pytest.mark.parametrize("x_size, bias_size", test_cases) -@pytest.mark.parametrize("dtype", dtypes) -def test_fast_gelu(x_size, bias_size, dtype): - for f in dtype_to_funcs("FastGeLU", dtype): - run_elementwise(x_size, bias_size, "FastGeLU", dtype, f) - - -@pytest.mark.parametrize("fn_name", fn_names) -@pytest.mark.parametrize("dtype", dtypes) -def test_elementwise_fns(fn_name, dtype): - for f in dtype_to_funcs(fn_name, dtype): - run_elementwise((1, 2, 768), 768, fn_name, dtype, f) - - -@dataclass -class ElementwiseMetric(ke.BandwidthMetric): - batch_size: int - seq_len: int - hidden_size: int - - def report(self): - common = f"{self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} {self.name}" - if self.duration > 0: - return f"{self.duration:>6.2f} us {self.gbps:>5.2f} GB/s " + common - return "not supported " + common - - -@ke.dispatchable(pattern_arg=4) -def profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func): - x_size = [batch_size, seq_len, hidden_size] - bias_size = hidden_size - np.random.seed(0) - x = np.random.rand(*x_size).astype(dtype) - bias = np.random.rand(bias_size).astype(dtype) - y = np.random.rand(*x_size).astype(dtype) - - x_d = ke.DeviceArray(x) - bias_d = ke.DeviceArray(bias) - y_d = ke.DeviceArray(y) - f = getattr(ke, func) - my_op = f(x_d, bias_d, y_d, x.size, bias.size) - - duration_ms = -1 - if my_op.IsSupported(): - duration_ms = my_op.Profile() - total_bytes = (x.size * 2 + bias.size) * dtype_to_bytes(dtype) - - ke.report(ElementwiseMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size)) - - -@ke.dispatchable -def profile_with_args(batch_size, seq_len, hidden_size, fn_name, dtype): - with ke.benchmark(): - for func in dtype_to_funcs(fn_name, dtype): - profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func) - - -def profile(): - 
for dtype in dtypes: - for bert_size in get_bert_sizes(): - profile_with_args(*bert_size, "FastGeLU", dtype) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("batch_size", type=int) - group.add_argument("seq_len", type=int) - group.add_argument("hidden_size", type=int) - group.add_argument("fn_name", choices=fn_names) - group.add_argument("dtype", choices=dtypes) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.batch_size, args.seq_len, args.hidden_size, args.fn_name, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py deleted file mode 100644 index 8ee9c6bc0f040..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py +++ /dev/null @@ -1,195 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import ( - dtype_to_suffix, - fast_gelu, - get_gemm_basic_sizes, - get_gemm_bert_sizes, - get_gemm_bound, - matmul, - transab_to_suffix, -) - - -# TODO The test method needs update. -def _test_gemmfastgelu(my_func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): - assert dtype in ["float16", "float32"] - - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - a = (np.random.rand(*a_shape)).astype(dtype).astype("float64") - b = (np.random.rand(*b_shape)).astype(dtype).astype("float64") - bias = (np.random.rand(n)).astype(dtype) - temp_c = matmul(a, b, transa, transb) - - bound = get_gemm_bound(dtype, a, b, temp_c, transa, transb, a_b_positive=True) - - temp_c = temp_c.astype(dtype) - ref_c = fast_gelu(temp_c, bias) - - a = a.astype(dtype) - b = b.astype(dtype) - - my_c = np.zeros((m, n), dtype=dtype) - dev_a = ke.DeviceArray(a) - dev_b = ke.DeviceArray(b) - dev_bias = ke.DeviceArray(bias) - dev_c = ke.DeviceArray(my_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - alpha = 1.0 - beta = 0.0 - my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) - - print(f"dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {max(bound, 1e-2)}") - - for impl in my_op.ListOps(): - if not my_op.SelectOp(impl): - continue - - my_op.Run() - dev_c.UpdateHostNumpyArray() - - np.testing.assert_allclose(my_c, ref_c, rtol=max(bound, 1e-2)) - - -dtypes = ["float16", "float32"] -all_transabs = list(product([True, False], repeat=2)) - - -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transab", all_transabs) -def test_gemmfastgelu_unfused_bert_cases(dtype, size, transab): - _test_gemmfastgelu(getattr(ke, "GemmFastGeluUnfused_" + dtype_to_suffix(dtype)), dtype, *size, *transab) - - -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) 
-@pytest.mark.parametrize("transab", all_transabs) -def test_gemmfastgelu_tunable_bert_cases(dtype, size, transab): - wrapper_name = f"GemmFastGeluTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix(transab)}" - _test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab) - - -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transab", all_transabs) -def test_gemmfastgelu_ck_bert_cases(dtype, size, transab): - wrapper_name = f"CKGemmFastGelu_{dtype_to_suffix(dtype)}_{transab_to_suffix(transab)}" - _test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab) - - -@pytest.mark.skipif(not ke.is_hipblaslt_available(), reason="hipblaslt is not available") -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transab", all_transabs) -def test_gemmfastgelu_hipblaslt_bert_cases(dtype, size, transab): - _test_gemmfastgelu(getattr(ke, "GemmFastGeluHipBlasLt_" + dtype_to_suffix(dtype)), dtype, *size, *transab) - - -@dataclass -class GemmFastGeluMetric(ke.ComputeMetric): - transa: bool - transb: bool - m: int - n: int - k: int - - def report(self): - transab = transab_to_suffix((self.transa, self.transb)) - common = f"{self.dtype} m={self.m:<4} n={self.n:<4} k={self.k:<4} {transab}, {self.name}" - if self.duration <= 0: - return "not supported " + common - return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common - - -@ke.dispatchable(pattern_arg=0) -def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool): - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - a = (np.random.rand(*a_shape) * 2 - 1).astype(dtype) - b = (np.random.rand(*b_shape) * 2 - 1).astype(dtype) - my_c = np.zeros((m, n), dtype=dtype) - bias = np.random.rand(n).astype(dtype) - - dev_a = ke.DeviceArray(a) - dev_b = ke.DeviceArray(b) - dev_bias = ke.DeviceArray(bias) - dev_c = ke.DeviceArray(my_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - alpha = 1.0 - beta = 0.0 - my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) - - for impl in my_op.ListOps(): - duration_ms = -1 - if my_op.SelectOp(impl): - duration_ms = my_op.Profile() - # only counts gemm tflops because fastgelu is low order term (7 * n). 
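        # A GEMM over (m, k) x (k, n) performs m * n * k multiply-accumulate steps,
        # i.e. 2 * m * n * k floating point operations; that count is what feeds the
        # tflops figure printed by GemmFastGeluMetric.report().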
- floating_point_operations = m * k * n * 2 - - ke.report(GemmFastGeluMetric(impl, dtype, duration_ms, floating_point_operations, transa, transb, m, n, k)) - - -@ke.dispatchable -def profile_with_args(transa, transb, dtype, m, n, k): - dtype_suffix = "_" + dtype_to_suffix(dtype) - transab_suffix = "_" + transab_to_suffix((transa, transb)) - with ke.benchmark(): - profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluUnfused" + dtype_suffix), dtype, m, n, k, transa, transb) - profile_gemmfastgelu_func( - getattr(ke, "CKGemmFastGelu" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb - ) - profile_gemmfastgelu_func( - getattr(ke, "GemmFastGeluTunable" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb - ) - if ke.is_hipblaslt_available(): - profile_gemmfastgelu_func( - getattr(ke, "GemmFastGeluHipBlasLt" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb - ) - - -def profile(): - for dtype in dtypes: - for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(False, False, dtype, m, n, k) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("transa", choices="NT") - group.add_argument("transb", choices="NT") - group.add_argument("dtype", choices=dtypes) - group.add_argument("m", type=int) - group.add_argument("n", type=int) - group.add_argument("k", type=int) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py deleted file mode 100644 index 76d0b2a3138bc..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py +++ /dev/null @@ -1,305 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
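Before the float8 GEMM tests: the inputs are scaled into the fp8 representable range before casting, and the kernels are handed the inverse scales so results come back at the original magnitude. `compute_scaling_factor` below picks an (approximately) power-of-two scale from the tensor's absolute max with a safety margin. A rough illustration of the quantize/dequantize round trip, using `ml_dtypes` as the test does (the simplified scale here ignores the rounding details of the real helper):

```python
import numpy as np
from ml_dtypes import finfo, float8_e4m3fn

a = np.random.rand(4, 4).astype(np.float64)
fp8_max = float(finfo(float8_e4m3fn).max)

# Power-of-two scale so amax lands comfortably below fp8_max, with a margin of
# 4 exponent steps to mirror the margin used by compute_scaling_factor below.
amax = np.abs(a).max()
sf = 2.0 ** (np.floor(np.log2(fp8_max / amax)) - 4)

a_fp8 = (a * sf).astype(float8_e4m3fn)          # quantize
a_back = a_fp8.astype(np.float64) * (1.0 / sf)  # dequantize with the inverse scale
print(np.max(np.abs(a_back - a)))               # small, bounded by fp8 precision
```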
-# -------------------------------------------------------------------------- - -from dataclasses import dataclass - -import kernel_explorer as ke -import numpy as np -import pytest -from ml_dtypes import finfo, float8_e4m3fn, float8_e4m3fnuz -from utils import dtype_to_bytes, dtype_to_suffix, get_gemm_bert_sizes, matmul, transab_to_suffix - - -def create_device_array(a): - ptr = a.__array_interface__["data"][0] - size = a.size - itemsize = finfo(a.dtype).bits // 8 - return ke.DeviceArray(ptr, size, itemsize) - - -def compute_scaling_factor(a: np.ndarray, fp8_max: float, margin: int) -> np.ndarray: - amax = np.abs(a).max() - scale = (fp8_max - margin) / amax # fallback scale - exp = np.floor(np.log2(fp8_max / amax)) - margin - sf = np.round(np.power(2, np.abs(exp))) - sf = np.where(amax > 0.0, sf, scale) - sf = np.where(np.isfinite(amax), sf, scale) - sf = np.where(exp < 0, 1 / sf, sf) - - return sf - - -def cast_and_scale(a, dtype: str): - if dtype == "float16": - return a.astype(dtype), 1.0 - elif np.dtype(dtype) in (float8_e4m3fn, float8_e4m3fnuz): - t = globals()[dtype] - sf = compute_scaling_factor(a, fp8_max=finfo(t).max, margin=4) - return (a * sf).astype(t), sf - else: - raise ValueError(dtype) - - -@ke.dispatchable(pattern_arg=0) -def _test_gemm( - func, dta: str, dtb: str, dtc: str, transa: bool, transb: bool, m: int, n: int, k: int, alpha=1.0, beta=0.0 -): - assert beta == 0.0, "beta is not supported" - assert dta in ["float16", "float8_e4m3fn", "float8_e4m3fnuz"] - assert dtb in ["float16", "float8_e4m3fn", "float8_e4m3fnuz"] - assert dtc in ["float16"] - - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - - a, scale_a = cast_and_scale(np.random.rand(*a_shape), dta) - b, scale_b = cast_and_scale(np.random.rand(*b_shape), dtb) - scale_c = float("nan") - - inv_scale_a = np.array(1 / scale_a).astype("float32") - inv_scale_b = np.array(1 / scale_b).astype("float32") - inv_scale_c = np.array(1 / scale_c).astype("float32") - - ref_c = matmul(a * inv_scale_a, b * inv_scale_b, transa, transb) - if alpha != 1.0: - ref_c *= alpha - - my_c = np.ones((m, n), dtype=dtc) - dev_a = create_device_array(a) - dev_b = create_device_array(b) - dev_c = create_device_array(my_c) - dev_inv_scale_a = create_device_array(inv_scale_a) - dev_inv_scale_b = create_device_array(inv_scale_b) - dev_inv_scale_c = create_device_array(inv_scale_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - my_gemm = func( - opa, - opb, - m, - n, - k, - alpha, - dev_a, - lda, - dev_inv_scale_a, - dev_b, - ldb, - dev_inv_scale_b, - beta, - dev_c, - n, - dev_inv_scale_c, - ) - - failures = {} - - # TODO: how to derive the bound for fp8? 
- atol = 0.01 - rtol = 0.005 - print(f"atol={atol} rtol={rtol}") # print for pytest -s -v - - for impl in my_gemm.ListOps(): - if not my_gemm.SelectOp(impl): - continue - # Restore C Array - my_c.fill(1.0) - dev_c.UpdateDeviceArray() - my_gemm.Run() - dev_c.UpdateHostNumpyArray() - - try: - np.testing.assert_allclose(my_c, ref_c, atol=atol, rtol=rtol) - except Exception as err: - header = "*" * 30 + impl + "*" * 30 - print(header) - print(err) - print("*" * len(header)) - failures[impl] = str(err) - - if failures: - raise Exception(failures) - - -dtypes = [ - ("float8_e4m3fn", "float16", "float16"), - ("float8_e4m3fnuz", "float16", "float16"), - ("float16", "float8_e4m3fn", "float16"), - ("float16", "float8_e4m3fnuz", "float16"), -] -all_transabs = [(False, False), (False, True)] - - -@pytest.mark.skipif(not ke.is_float8_available(), reason="float8 is not enabled") -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize( - "m, n, k", - [ - (1, 768, 768), - (768, 768, 768), - (1, 8192, 28672), - (1, 28672, 8192), - (1, 8192, 8192), - (128, 8192, 28672), - (128, 28672, 8192), - (128, 8192, 8192), - ], -) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dta, dtb, dtc", dtypes) -@ke.dispatchable -def test_ck_gemm(dta, dtb, dtc, transa, transb, m, n, k): - if dtb == "float16" and transb: - pytest.skip("Only supports transb when b is fp8") - wrapper_name = f"GemmFloat8CK_{dtype_to_suffix(dta)}_{dtype_to_suffix(dtb)}_{dtype_to_suffix(dtc)}_{transab_to_suffix((transa, transb))}" - _test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k) - - -@pytest.mark.skipif(not ke.is_float8_available(), reason="float8 is not enabled") -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("alpha, beta", [(1.5, 0.0), [2.0, 0.0]]) -@pytest.mark.parametrize("m, n, k", [(768, 768, 768)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dta, dtb, dtc", dtypes) -def test_ck_gemm_alpha_beta(dta, dtb, dtc, transa, transb, m, n, k, alpha, beta): - if dtb == "float16" and transb: - pytest.skip("Only supports transb when b is fp8") - wrapper_name = f"GemmFloat8CK_{dtype_to_suffix(dta)}_{dtype_to_suffix(dtb)}_{dtype_to_suffix(dtc)}_{transab_to_suffix((transa, transb))}" - _test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k, alpha, beta) - - -@pytest.mark.skipif(not ke.is_float8_available(), reason="float8 is not enabled") -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("alpha, beta", [(1.5, 0.0), [2.0, 0.0]]) -@pytest.mark.parametrize("m, n, k", [(256, 256, 256)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dta, dtb, dtc", dtypes) -def test_tunable_gemm(dta, dtb, dtc, transa, transb, m, n, k, alpha, beta): - if dtb == "float16" and transb: - pytest.skip("Only supports transb when b is fp8") - wrapper_name = f"GemmFloat8Tunable_{dtype_to_suffix(dta)}_{dtype_to_suffix(dtb)}_{dtype_to_suffix(dtc)}_{transab_to_suffix((transa, transb))}" - _test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k, alpha, beta) - - -@dataclass -class GemmMetric(ke.BandwidthMetric, ke.ComputeMetric): - transa: bool - transb: bool - m: int - n: int - k: int - - def report(self): - common = ( - f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} " - f"m={self.m:<4} 
n={self.n:<4} k={self.k:<4} {self.name}" - ) - if self.duration <= 0: - return "not supported " + common - - return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops {self.gbps:5.2f} GB/s " + common - - -@ke.dispatchable(pattern_arg=0) -def profile_gemm_func( - func, dta: str, dtb: str, dtc: str, transa: bool, transb: bool, m: int, n: int, k: int, alpha=1.0, beta=0.0 -): - assert beta == 0.0, "beta is not supported" - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - a, scale_a = cast_and_scale(np.random.rand(*a_shape) + 0.1, dta) - b, scale_b = cast_and_scale(np.random.rand(*b_shape) + 0.1, dtb) - scale_c = 1.0 - - inv_scale_a = np.array(1 / scale_a).astype("float32") - inv_scale_b = np.array(1 / scale_b).astype("float32") - inv_scale_c = np.array(1 / scale_c).astype("float32") - - my_c = np.ones((m, n), dtype=dtc) - - dev_a = create_device_array(a) - dev_b = create_device_array(b) - dev_c = create_device_array(my_c) - dev_inv_scale_a = create_device_array(inv_scale_a) - dev_inv_scale_b = create_device_array(inv_scale_b) - dev_inv_scale_c = create_device_array(inv_scale_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - my_gemm = func( - opa, - opb, - m, - n, - k, - alpha, - dev_a, - lda, - dev_inv_scale_a, - dev_b, - ldb, - dev_inv_scale_b, - beta, - dev_c, - n, - dev_inv_scale_c, - ) - - for impl in my_gemm.ListOps(): - duration_ms = -1 - if my_gemm.SelectOp(impl): - duration_ms = my_gemm.Profile() - FLOPs = m * k * n * 2 # noqa: N806 - total_bytes = m * k * dtype_to_bytes(dta) + k * n * dtype_to_bytes(dtb) + m * n * dtype_to_bytes(dtc) - - ke.report(GemmMetric(impl, f"{dta}_{dtb}_{dtc}", duration_ms, FLOPs, total_bytes, transa, transb, m, n, k)) - - -@ke.dispatchable -def profile_with_args(dta, dtb, dtc, transa, transb, m, n, k): - dtype_suffix = "_" + dtype_to_suffix(dta) + "_" + dtype_to_suffix(dtb) + "_" + dtype_to_suffix(dtc) - transab_suffix = "_" + transab_to_suffix((transa, transb)) - with ke.benchmark(): - profile_gemm_func( - getattr(ke, "GemmFloat8CK" + dtype_suffix + transab_suffix), dta, dtb, dtc, transa, transb, m, n, k - ) - profile_gemm_func( - getattr(ke, "GemmFloat8Tunable" + dtype_suffix + transab_suffix), dta, dtb, dtc, transa, transb, m, n, k - ) - print() - - -def profile(): - for dta, dtb, dtc in dtypes: - for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(dta, dtb, dtc, False, False, m, n, k) - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("dta", choices=["float8_e4m3fn", "float8_e4m3fnuz", "float16"]) - group.add_argument("dtb", choices=["float8_e4m3fn", "float8_e4m3fnuz", "float16"]) - group.add_argument("dtc", choices=["float8_e4m3fn", "float8_e4m3fnuz", "float16"]) - group.add_argument("transa", choices="NT") - group.add_argument("transb", choices="NT") - group.add_argument("m", type=int) - group.add_argument("n", type=int) - group.add_argument("k", type=int) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.dta, args.dtb, args.dtc, args.transa == "T", args.transb == "T", args.m, args.n, args.k) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py deleted file mode 100644 index aedcc0c5b71ce..0000000000000 --- 
a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +++ /dev/null @@ -1,613 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - - -import os -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import dtype_to_suffix, matmul, softmax - -max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", "64")) - - -def multinormal_distribution(num_distribution, num_element_per_dist): - arrays = [] - for _ in range(num_distribution): - mean = np.random.rand() - 0.5 - std = np.random.rand() + 0.5 - arrays.append(np.random.normal(mean, std, (num_element_per_dist,))) - return np.array(arrays) - - -def get_ck_binding_name(dtype, biased: bool, masked: bool): - dtype_suffix = "_" + dtype_to_suffix(dtype) - ck_suffix = "" - if biased: - ck_suffix += "Biased" - if masked: - ck_suffix += "Masked" - ck_suffix += dtype_suffix - return "GemmSoftmaxGemmPermuteCK" + ck_suffix - - -dtypes = ["float16"] -batches = [1, max_batch_size] -seqlens = [128, 512] -total_seqlens = [128, 512] -num_heads = [8, 12] -head_sizes = [64] -biaseds = [False, True] -causals = [False] -mask_dims = [0, 2, 3, 4] - - -def get_biased_id(biased): - return "biased" if biased else "nobias" - - -def get_mask_dim_id(dim): - if dim == 0: - return "nomask" - return f"mask_{dim}d" - - -def maybe_pack_q_k_v_bnsh_for_device_on_host(q, k, v, dtype, qkv_format): - q = q.astype(dtype) - k = k.astype(dtype) - v = v.astype(dtype) - if qkv_format == ke.qkv_format.Q_K_V_BNSH: - return q, k, v - - # BNSH to BSNH - q = np.swapaxes(q, 2, 1) - k = np.swapaxes(k, 2, 1) - v = np.swapaxes(v, 2, 1) - - if qkv_format == ke.qkv_format.Q_K_V_BSNH: - return np.ascontiguousarray(q), np.ascontiguousarray(k), np.ascontiguousarray(v) - - if qkv_format == ke.qkv_format.QKV_BSN3H: - return np.ascontiguousarray(np.stack([q, k, v], axis=-2)), None, None - - if qkv_format == ke.qkv_format.Q_KV_BSNH_BSN2H: - return np.ascontiguousarray(q), np.ascontiguousarray(np.stack([k, v], axis=-2)), None - - raise NotImplementedError - - -def _make_causal_mask( - seqence_length, - total_sequence_length, - dtype: np.dtype, -): - """ - Make causal mask used for Attention with attribute unidirectional == 1. - The mask is a upper triangular matrix with shape [sequence_length, total_sequence_length]. - Putting a 1 indicates that the token at this position should be masked. - For Example: - sequence_length = 5, total_sequence_length = 5, - mask: [[0. 1. 1. 1. 1.] - [0. 0. 1. 1. 1.] - [0. 0. 0. 1. 1.] - [0. 0. 0. 0. 1.] - [0. 0. 0. 0. 0.]] - seqence_length = 5, total_seqence_length = 3, - mask: [[1. 1. 1.] - [1. 1. 1.] - [0. 1. 1.] - [0. 0. 1.] - [0. 0. 0.]] - seqence_length = 5, total_seqence_length = 7, - mask: [[0. 0. 0. 1. 1. 1. 1.] - [0. 0. 0. 0. 1. 1. 1.] - [0. 0. 0. 0. 0. 1. 1.] - [0. 0. 0. 0. 0. 0. 1.] - [0. 0. 0. 0. 0. 0. 
0.]] - """ - mask = np.full((seqence_length, seqence_length), 1) - mask_cond = np.arange(mask.shape[-1]) - mask = np.where(mask_cond < (mask_cond + 1).reshape(mask.shape[-1], 1), 0, mask) - - mask = mask.astype(dtype) - - if total_sequence_length - seqence_length > 0: - mask = np.concatenate( - [np.zeros((seqence_length, total_sequence_length - seqence_length), dtype=dtype), mask], axis=-1 - ) - - if total_sequence_length - seqence_length < 0: - mask = mask[:, -total_sequence_length:] - - correct_mask = np.full((seqence_length, total_sequence_length), 1) - for i in range(seqence_length): - correct_mask[i][:] = sum(mask[i]) != total_sequence_length - return mask, correct_mask - - -def _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format -): - v_head_size = head_size - q_shape = [batch, num_heads, seqlen, head_size] - k_shape = [batch, num_heads, total_seqlen, head_size] - v_shape = [batch, num_heads, total_seqlen, v_head_size] - out_shape = [batch, seqlen, num_heads, head_size] - - attn_bias = None - bias_shape = [batch, num_heads, seqlen, total_seqlen] if biased else None - - attn_mask = None - mask_shape = None - mask_shape_broadcasted = None - max_seqlen = None - if mask_dim != 0: - if mask_dim == 2: - mask_shape = [batch, total_seqlen] - mask_shape_broadcasted = [batch, 1, 1, total_seqlen] - elif mask_dim == 3: - mask_shape = [batch, seqlen, total_seqlen] - mask_shape_broadcasted = [batch, 1, seqlen, total_seqlen] - elif mask_dim == 4: - max_seqlen = ((seqlen - 1) // 1024 + 1) * 1024 # round up to multiple of 1024 - mask_shape = [batch, 1, max_seqlen, max_seqlen] - else: - raise ValueError - - np.random.seed(42) - q = multinormal_distribution(np.prod(q_shape[:-1]), q_shape[-1]).reshape(q_shape).astype(np.float64) - k = multinormal_distribution(np.prod(k_shape[:-1]), k_shape[-1]).reshape(k_shape).astype(np.float64) - v = multinormal_distribution(np.prod(v_shape[:-1]), v_shape[-1]).reshape(v_shape).astype(np.float64) - if bias_shape is not None: - attn_bias = np.random.uniform(-0.5, 0.5, size=bias_shape) - if mask_shape is not None: - attn_mask = (np.random.randint(0, 100, size=mask_shape) < 95).astype(np.int32) - - pre_softmax_attn_scores = matmul(q, np.swapaxes(k, 2, 3)) - pre_softmax_attn_scores = pre_softmax_attn_scores * scale - if attn_bias is not None: - pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias - - correct_causal_mask = np.full((seqlen, total_seqlen), 1) - if attn_mask is not None: - filter_value = -10000.0 - if mask_dim == 4: - # equivalent to past_sequence_length = max_sequence_length - seqlen - converted_mask = (1 - attn_mask[:, :, -seqlen:, :total_seqlen]) * filter_value - else: - converted_mask = (1 - attn_mask.reshape(mask_shape_broadcasted)) * filter_value - pre_softmax_attn_scores = pre_softmax_attn_scores + converted_mask - if causal: - filter_value = np.finfo(dtype).min - causal_mask, correct_causal_mask = _make_causal_mask(seqlen, total_seqlen, pre_softmax_attn_scores.dtype) - causal_mask = np.broadcast_to(causal_mask, pre_softmax_attn_scores.shape) * filter_value - pre_softmax_attn_scores = pre_softmax_attn_scores + causal_mask - attn_scores = softmax(pre_softmax_attn_scores, axis=-1) - - # apply mask to attn_scores to correct softmax result, in c++ implementation, if all values in a row are masked, - # the softmax result in this row will be filled with 0. 
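# Illustrative numpy sketch (the _demo_* names are placeholders, not from this test):
# when every logit in a row is the same large negative value, naive softmax returns a
# uniform row instead of zeros, so such rows are zeroed explicitly below to match the
# C++ kernel's convention for fully masked rows.
_demo_row = np.full((1, 4), -10000.0)      # one fully masked row of logits
_demo_naive = softmax(_demo_row, axis=-1)  # -> [[0.25, 0.25, 0.25, 0.25]], not zeros
_demo_fixed = _demo_naive * 0.0            # row-wise correction mask zeroes it out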
- correct_causal_mask = np.broadcast_to(correct_causal_mask, pre_softmax_attn_scores.shape) - attn_scores = attn_scores * correct_causal_mask - - attn = matmul(attn_scores, v) - ref = np.swapaxes(attn, 2, 1) # permute 0213 - - out = np.empty(out_shape, dtype=dtype) - host_q, host_k, host_v = maybe_pack_q_k_v_bnsh_for_device_on_host(q, k, v, dtype, qkv_format) - host_attn_bias = attn_bias.astype(dtype) if attn_bias is not None else None - dev_q = ke.DeviceArray(host_q) - dev_k = ke.DeviceArray(host_k) if host_k is not None else None - dev_v = ke.DeviceArray(host_v) if host_v is not None else None - dev_out = ke.DeviceArray(out) - dev_attn_bias = ke.DeviceArray(host_attn_bias) if host_attn_bias is not None else None - dev_attn_mask = ke.DeviceArray(attn_mask) if attn_mask is not None else None - - my_gemm_softmax_gemm_permute = f( - batch, - seqlen, - total_seqlen, - max_seqlen, - num_heads, - head_size, - mask_dim, - scale, - causal, - qkv_format, - dev_q, - dev_k, - dev_v, - dev_attn_bias, - dev_attn_mask, - dev_out, - ) - - print() # write an empty line in case pytest ... -s -v - failures = {} - for impl in my_gemm_softmax_gemm_permute.ListOps(): - if not my_gemm_softmax_gemm_permute.SelectOp(impl): - print("Unsupport", impl) - continue - print(" Support", impl) - - my_gemm_softmax_gemm_permute.Run() - dev_out.UpdateHostNumpyArray() - - try: - is_strict = int(os.environ.get("KERNEL_EXPLORER_STRICT_TEST", "0")) - if is_strict: - # NOTE: this will always fail, just for manual checking with: - # KERNEL_EXPLORER_STRICT_TEST=1 pytest ... -s -v - np.testing.assert_allclose(out, ref) - else: - is_zero_tol, atol, rtol = 1e-3, 2e-2, 1e-2 - not_close_to_zeros = np.abs(ref) > is_zero_tol - np.testing.assert_allclose(out[not_close_to_zeros], ref[not_close_to_zeros], atol=atol, rtol=rtol) - except Exception as err: - header = "*" * 30 + impl + "*" * 30 - print(header) - print(err) - print("*" * len(header)) - failures[impl] = str(err) - - if failures: - raise Exception(failures) - - -@pytest.mark.parametrize("mask_dim", mask_dims, ids=get_mask_dim_id) -@pytest.mark.parametrize("biased", biaseds, ids=get_biased_id) -@pytest.mark.parametrize("head_size", head_sizes) -@pytest.mark.parametrize("nhead", num_heads) -@pytest.mark.parametrize("total_seqlen", total_seqlens) -@pytest.mark.parametrize("seqlen", seqlens) -@pytest.mark.parametrize("batch", [16]) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("dtype", ["float16", "float32"]) -def test_gemm_softmax_gemm_permute_generic( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim -): - f = getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)) - scale = 1.0 / np.sqrt(head_size) - _test_gemm_softmax_gemm_permute( - f, - dtype, - batch, - seqlen, - total_seqlen, - nhead, - head_size, - biased, - mask_dim, - scale, - causal, - ke.qkv_format.Q_K_V_BNSH, - ) - - -@pytest.mark.parametrize("mask_dim", [2], ids=get_mask_dim_id) -@pytest.mark.parametrize("biased", [False], ids=get_biased_id) -@pytest.mark.parametrize("head_size", [64]) -@pytest.mark.parametrize("nhead", [8]) -@pytest.mark.parametrize("total_seqlen", [128]) -@pytest.mark.parametrize("seqlen", [64]) -@pytest.mark.parametrize("batch", [16]) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("dtype", ["float16", "float32"]) -def test_gemm_softmax_gemm_permute_generic_nested_tunable( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim -): - f = getattr(ke, 
"GemmSoftmaxGemmPermuteGenericNestedTunable_" + dtype_to_suffix(dtype)) - scale = 1.0 / np.sqrt(head_size) - _test_gemm_softmax_gemm_permute( - f, - dtype, - batch, - seqlen, - total_seqlen, - nhead, - head_size, - biased, - mask_dim, - scale, - causal, - ke.qkv_format.Q_K_V_BNSH, - ) - - -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("mask_dim", mask_dims, ids=get_mask_dim_id) -@pytest.mark.parametrize("biased", biaseds, ids=get_biased_id) -@pytest.mark.parametrize("head_size", head_sizes) -@pytest.mark.parametrize("nhead", num_heads) -@pytest.mark.parametrize("total_seqlen", total_seqlens) -@pytest.mark.parametrize("seqlen", seqlens) -@pytest.mark.parametrize("batch", batches) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("dtype", dtypes) -def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim): - f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) - scale = 1.0 / np.sqrt(head_size) - _test_gemm_softmax_gemm_permute( - f, - dtype, - batch, - seqlen, - total_seqlen, - nhead, - head_size, - biased, - mask_dim, - scale, - causal, - ke.qkv_format.Q_K_V_BNSH, - ) - - -@pytest.mark.parametrize("mask_dim", [2], ids=get_mask_dim_id) -@pytest.mark.parametrize("biased", [False], ids=get_biased_id) -@pytest.mark.parametrize("head_size", [64]) -@pytest.mark.parametrize("nhead", [8]) -@pytest.mark.parametrize("total_seqlen", [128]) -@pytest.mark.parametrize("seqlen", [64]) -@pytest.mark.parametrize("batch", [16]) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("dtype", ["float16"]) -def test_gemm_softmax_gemm_permute_tunable( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim -): - f = getattr(ke, "GemmSoftmaxGemmPermuteTunable_" + dtype_to_suffix(dtype)) - scale = 1.0 / np.sqrt(head_size) - _test_gemm_softmax_gemm_permute( - f, - dtype, - batch, - seqlen, - total_seqlen, - nhead, - head_size, - biased, - mask_dim, - scale, - causal, - ke.qkv_format.Q_K_V_BNSH, - ) - - -stabel_diffusion_configs = [ - [2, 64, 64, 8, 160, "QKV_BSN3H"], - [2, 256, 256, 8, 160, "QKV_BSN3H"], - [2, 1024, 1024, 8, 80, "QKV_BSN3H"], - [2, 4096, 4096, 8, 40, "QKV_BSN3H"], - [2, 64, 77, 8, 160, "Q_KV_BSNH_BSN2H"], - [2, 256, 77, 8, 160, "Q_KV_BSNH_BSN2H"], - [2, 1024, 77, 8, 80, "Q_KV_BSNH_BSN2H"], - [2, 4096, 77, 8, 40, "Q_KV_BSNH_BSN2H"], - [1, 4096, 4096, 1, 512, "Q_K_V_BNSH"], -] - - -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("mask_dim", [0], ids=get_mask_dim_id) -@pytest.mark.parametrize("biased", [False], ids=get_biased_id) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("batch, seqlen, total_seqlen, nhead, head_size, qkv_format_name", stabel_diffusion_configs) -@pytest.mark.parametrize("dtype", dtypes) -def test_gemm_softmax_gemm_permute_ck_sd( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim, qkv_format_name -): - qkv_format = getattr(ke.qkv_format, qkv_format_name) - f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) - scale = 1.0 / np.sqrt(head_size) - _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, causal, qkv_format - ) - - -@dataclass -class GemmSoftmaxGemmPermuteMetric(ke.ComputeMetric): - batch: int - seqlen: int - total_seqlen: int - num_heads: int - head_size: 
int - biased: bool - mask_dim: int - - def report(self): - bias_str = " biased" if self.biased else "" - mask_str = f" mask_{self.mask_dim}d" if self.mask_dim != 0 else "" - common = ( - f"{self.dtype} B={self.batch} S={self.seqlen} T={self.total_seqlen} " - f"N={self.num_heads} H={self.head_size}{bias_str}{mask_str}, " - f"{self.name}" - ) - if self.duration <= 0: - return "not supported " + common - - return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common - - -@ke.dispatchable(pattern_arg=0) -def profile_gemm_softmax_gemm_permute_func( - f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format -): - v_head_size = head_size - q_shape = [batch, num_heads, seqlen, head_size] - k_shape = [batch, num_heads, total_seqlen, head_size] - v_shape = [batch, num_heads, total_seqlen, v_head_size] - out_shape = [batch, seqlen, num_heads, head_size] - - attn_bias = None - bias_shape = [batch, num_heads, seqlen, total_seqlen] if biased else None - - attn_mask = None - mask_shape = None - max_seqlen = None - if mask_dim != 0: - if mask_dim == 2: - mask_shape = [batch, total_seqlen] - elif mask_dim == 3: - mask_shape = [batch, seqlen, total_seqlen] - elif mask_dim == 4: - max_seqlen = ((seqlen - 1) // 1024 + 1) * 1024 # round up to multiple of 1024 - mask_shape = [batch, 1, max_seqlen, max_seqlen] - else: - raise ValueError - - np.random.seed(42) - q = multinormal_distribution(np.prod(q_shape[:-1]), q_shape[-1]).reshape(q_shape).astype(np.float64) - k = multinormal_distribution(np.prod(k_shape[:-1]), k_shape[-1]).reshape(k_shape).astype(np.float64) - v = multinormal_distribution(np.prod(v_shape[:-1]), v_shape[-1]).reshape(v_shape).astype(np.float64) - if bias_shape is not None: - attn_bias = np.random.uniform(-2, 2, size=bias_shape) - if mask_shape is not None: - attn_mask = (np.random.randint(0, 100, size=mask_shape) < 95).astype(np.int32) - - out = np.empty(out_shape, dtype=dtype) - host_q, host_k, host_v = maybe_pack_q_k_v_bnsh_for_device_on_host(q, k, v, dtype, qkv_format) - host_attn_bias = attn_bias.astype(dtype) if attn_bias is not None else None - dev_q = ke.DeviceArray(host_q) - dev_k = ke.DeviceArray(host_k) if host_k is not None else None - dev_v = ke.DeviceArray(host_v) if host_v is not None else None - dev_out = ke.DeviceArray(out) - dev_attn_bias = ke.DeviceArray(host_attn_bias) if host_attn_bias is not None else None - dev_attn_mask = ke.DeviceArray(attn_mask) if attn_mask is not None else None - - my_gemm_softmax_gemm_permute = f( - batch, - seqlen, - total_seqlen, - max_seqlen, - num_heads, - head_size, - mask_dim, - scale, - causal, - qkv_format, - dev_q, - dev_k, - dev_v, - dev_attn_bias, - dev_attn_mask, - dev_out, - ) - - for impl in my_gemm_softmax_gemm_permute.ListOps(): - duration_ms = -1 - if my_gemm_softmax_gemm_permute.SelectOp(impl): - duration_ms = my_gemm_softmax_gemm_permute.Profile() - - m, n, k, o, gemm_batch = seqlen, total_seqlen, head_size, head_size, batch * num_heads - flops_per_batch = m * n * k * 2 + m * n * o * 2 - flops_count_bias_and_softmax = True # set to false to be aligned with ck - if flops_count_bias_and_softmax: - flops_per_batch += 2 * n + 1 - if flops_count_bias_and_softmax and attn_bias is not None: - flops_per_batch += m * n - if flops_count_bias_and_softmax and attn_mask is not None: - flops_per_batch += m * n - flops = flops_per_batch * gemm_batch - - ke.report( - GemmSoftmaxGemmPermuteMetric( - impl, dtype, duration_ms, flops, batch, seqlen, total_seqlen, num_heads, head_size, biased, 
mask_dim - ) - ) - - -@ke.dispatchable -def profile_with_args( - dtype, - batch, - seqlen, - total_seqlen, - num_heads, - head_size, - biased, - causal, - mask_dim, - scale, - qkv_format, -): - with ke.benchmark(): - args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format) - if qkv_format == ke.qkv_format.Q_K_V_BNSH: - profile_gemm_softmax_gemm_permute_func( - getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)), *args - ) - if ke.is_composable_kernel_available(): - profile_gemm_softmax_gemm_permute_func( - getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)), *args - ) - profile_gemm_softmax_gemm_permute_func( - getattr(ke, "GemmSoftmaxGemmPermuteTunable_" + dtype_to_suffix(dtype)), *args - ) - - -def profile(): - for batch, seqlen, total_seqlen, nhead, head_size, qkv_format_name in stabel_diffusion_configs: - profile_with_args( - "float16", - batch, - seqlen, - total_seqlen, - nhead, - head_size, - biased=False, - causal=False, - mask_dim=0, - qkv_format=getattr(ke.qkv_format, qkv_format_name), - scale=0.125, - ) - print() - - for args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, causals, mask_dims): - profile_with_args(*args, qkv_format=ke.qkv_format.Q_K_V_BNSH, scale=0.125) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("dtype", choices=dtypes) - group.add_argument("batch", type=int) - group.add_argument("seqlen", type=int) - group.add_argument("total_seqlen", type=int) - group.add_argument("num_heads", type=int) - group.add_argument("head_size", type=int) - group.add_argument("biased", type=int, choices=[0, 1], default=0) - group.add_argument("mask_dim", type=int, choices=[0, 2, 3, 4], default=2, help="0 for mask disabled") - group.add_argument("causal", type=int, choices=[0, 1], default=0) - group.add_argument("--scale", type=float, default=None, help="default to 1.0/sqrt(head_size)") - group.add_argument( - "--qkv_format", - default="Q_K_V_BNSH", - choices=[ - "Q_K_V_BNSH", # non-packed, permuted - "Q_K_V_BSNH", # non-packed, non-permuted - "Q_KV_BSNH_BSN2H", # kv packed, non-permuted - "QKV_BSN3H", # qkv packed, non-permuted - ], - ) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch( - args.dtype, - args.batch, - args.seqlen, - args.total_seqlen, - args.num_heads, - args.head_size, - args.biased, - args.causal, - args.mask_dim, - args.scale, - getattr(ke.qkv_format, args.qkv_format), - ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py deleted file mode 100644 index 23ffa5735d2c1..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ /dev/null @@ -1,217 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix - - -@ke.dispatchable -def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, alpha=1.0, beta=0.0): - assert dtype in ["float32", "float16", "float8_e4m3"] - - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - a = (np.random.rand(*a_shape) + 0.5).astype(dtype).astype("float64") - b = (np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64") - intermediate_c = matmul(a, b, transa, transb) - if alpha == 1.0 and beta == 0.0: # fast path - ref_c = intermediate_c - else: - ref_c = alpha * intermediate_c + beta * np.ones_like(intermediate_c) - - bound = get_gemm_bound(dtype, a, b, ref_c, transa, transb, a_b_positive=True) - - a = a.astype(dtype) - b = b.astype(dtype) - - my_c = np.ones((m, n), dtype=dtype) - dev_a = ke.DeviceArray(a) - dev_b = ke.DeviceArray(b) - dev_c = ke.DeviceArray(my_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - my_gemm = func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, beta, dev_c, n) - - failures = {} - print(f"dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {bound}") - - for impl in my_gemm.ListOps(): - if not my_gemm.SelectOp(impl): - continue - # Restore C Array - my_c.fill(1.0) - dev_c.UpdateDeviceArray() - my_gemm.Run() - dev_c.UpdateHostNumpyArray() - - try: - np.testing.assert_allclose(my_c, ref_c, rtol=bound) - except Exception as err: - header = "*" * 30 + impl + "*" * 30 - print(header) - print(err) - print("*" * len(header)) - failures[impl] = str(err) - - if failures: - raise Exception(failures) - - -dtypes = ["float32", "float16"] -all_transabs = list(product([True, False], repeat=2)) - - -@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=True) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k): - _test_gemm(getattr(ke, "RocBlasGemm_" + dtype_to_suffix(dtype)), dtype, transa, transb, m, n, k) - - -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_ck_gemm_bert_cases(dtype, transa, transb, m, n, k): - wrapper_name = f"CKGemm_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k) - - -# Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests -@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_gemm_tunable_bert_cases(dtype, transa, transb, m, n, k): - wrapper_name = f"GemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_gemm(getattr(ke, 
wrapper_name), dtype, transa, transb, m, n, k) - - -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_rocblas_gemm_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = "RocBlasGemm_" + dtype_to_suffix(dtype) - _test_gemm(getattr(ke, wrapper_name), dtype, transa, transb, 128, 256, 768, alpha=alpha, beta=beta) - - -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_ck_gemm_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = f"CKGemm_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_gemm(getattr(ke, wrapper_name), dtype, transa, transb, 256, 128, 384, alpha=alpha, beta=beta) - - -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_gemm_tunable_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = f"GemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_gemm(getattr(ke, wrapper_name), dtype, transa, transb, 128, 512, 384, alpha=alpha, beta=beta) - - -@dataclass -class GemmMetric(ke.ComputeMetric): - transa: bool - transb: bool - m: int - n: int - k: int - - def report(self): - common = ( - f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} " - f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.name}" - ) - if self.duration <= 0: - return "not supported " + common - - return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common - - -@ke.dispatchable(pattern_arg=0) -def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int): - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - a = (np.random.rand(*a_shape) * 2 - 1).astype(dtype) - b = (np.random.rand(*b_shape) * 2 - 1).astype(dtype) - my_c = np.zeros((m, n), dtype=dtype) - - dev_a = ke.DeviceArray(a) - dev_b = ke.DeviceArray(b) - dev_c = ke.DeviceArray(my_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - alpha = 1.0 - beta = 0.0 - my_gemm = f(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, beta, dev_c, n) - - for impl in my_gemm.ListOps(): - duration_ms = -1 - if my_gemm.SelectOp(impl): - duration_ms = my_gemm.Profile() - FLOPs = m * k * n * 2 # noqa: N806 - - ke.report(GemmMetric(impl, dtype, duration_ms, FLOPs, transa, transb, m, n, k)) - - -@ke.dispatchable -def profile_with_args(dtype, transa, transb, m, n, k): - dtype_suffix = "_" + dtype_to_suffix(dtype) - transab_suffix = "_" + transab_to_suffix((transa, transb)) - with ke.benchmark(): - if ke.is_rocm_available(): - profile_gemm_func(getattr(ke, "RocBlasGemm" + dtype_suffix), dtype, transa, transb, m, n, k) - profile_gemm_func(getattr(ke, "CKGemm" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) - profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) - if ke.is_cuda_available(): - profile_gemm_func(getattr(ke, "GemmBenchmark" + dtype_suffix), dtype, transa, transb, m, n, k) - if ke.is_hipblaslt_available(): - profile_gemm_func( - getattr(ke, "GemmHipBlasLt" + dtype_suffix + transab_suffix), dtype, transa, 
transb, m, n, k - ) - print() - - -def profile(): - for dtype in dtypes: - for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(dtype, False, False, m, n, k) - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("dtype", choices=dtypes) - group.add_argument("transa", choices="NT") - group.add_argument("transb", choices="NT") - group.add_argument("m", type=int) - group.add_argument("n", type=int) - group.add_argument("k", type=int) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py deleted file mode 100644 index a45b9e80500cc..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ /dev/null @@ -1,321 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import re -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import dtype_to_bytes, dtype_to_suffix, standardization - - -def get_sd_sizes(): - batch_sizes = [1, 2] - height = [8, 16, 32] - num_channels = [320, 640, 1280, 1920, 2560] - - num_groups = [32] - return product(batch_sizes, height, num_channels, num_groups) - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: re.match("GroupNormNHWC.*_half", x), dir(ke))), - "float32": list(filter(lambda x: re.match("GroupNormNHWC.*_float", x), dir(ke))), - } - return type_map[dtype] - - -def sigmoid_function(x): - return 1.0 / (1.0 + np.exp(-x)) - - -def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip): - add_output = None - if has_skip: - input_x = input_x + skip_x + bias_x - add_output = input_x - n, h, w, c = input_x.shape - input_x = input_x.transpose([0, 3, 1, 2]) - assert c % num_groups == 0 - x = input_x.reshape((n, num_groups, -1)) - x = standardization(x, -1, epsilon) - x = x.reshape((n, c, h, w)) - x = x.transpose([0, 2, 3, 1]) - x = x * gamma + beta - - if with_silu: - x = x * sigmoid_function(x) - return x, add_output - - -def run_group_norm( - batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func -): - np.random.seed(0) - width = height - input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32) - gamma = np.random.rand(num_channels).astype(np.float32) - beta = np.random.rand(num_channels).astype(np.float32) - # the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18 - workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32) - epsilon = 1e-05 - output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - - skip_x = ( - np.random.rand(batch_size, height, width, num_channels).astype(np.float32) - if has_skip - else np.empty((0), dtype=dtype) - ) - bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype) - add_output = ( - np.random.rand(batch_size, height, width, num_channels).astype(dtype) - if 
has_skip - else np.empty((0), dtype=dtype) - ) - use_silu = silu - broadcast_skip = False - if has_skip: - skip_x_shape = skip_x.shape - b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels - b4 = ( - len(skip_x_shape) == 4 - and skip_x_shape[0] == batch_size - and skip_x_shape[1] == 1 - and skip_x_shape[2] == 1 - and skip_x_shape[3] == num_channels - ) - if b2 or b4: - broadcast_skip = True - channels_per_block = 0 # Compute in params initialization - - input_d = ke.DeviceArray(input_x.astype(dtype)) - skip_d = ke.DeviceArray(skip_x.astype(dtype)) - bias_d = ke.DeviceArray(bias_x.astype(dtype)) - gamma_d = ke.DeviceArray(gamma) - beta_d = ke.DeviceArray(beta) - workspace_d = ke.DeviceArray(workspace) - y_d = ke.DeviceArray(output_y) - y_add_d = ke.DeviceArray(add_output) - f = getattr(ke, func) - - my_op = f( - y_d, - y_add_d, - input_d, - skip_d, - bias_d, - gamma_d, - beta_d, - workspace_d, - epsilon, - batch_size, - num_channels, - height, - width, - num_groups, - use_silu, - broadcast_skip, - channels_per_block, - ) - y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip) - y_ref = y_ref.astype(dtype) - - for impl in my_op.ListOps(): - if not my_op.SelectOp(impl): - continue - - my_op.Run() - - y_d.UpdateHostNumpyArray() - - np.testing.assert_allclose(y_ref, output_y, atol=1e-02) - if has_skip: - y_add_d_ref = y_add_d_ref.astype(dtype) - y_add_d.UpdateHostNumpyArray() - np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02) - - -dtypes = ["float32", "float16"] - - -@pytest.mark.parametrize("sd_sizes", get_sd_sizes()) -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("silu", [True]) -@pytest.mark.parametrize("has_skip", [True, False]) -def test_group_norm(sd_sizes, dtype, silu, has_skip): - for func in dtype_to_funcs(dtype): - run_group_norm(*sd_sizes, dtype, silu, has_skip, func) - - -@pytest.mark.parametrize("sd_sizes", get_sd_sizes()) -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("silu", [True]) -@pytest.mark.parametrize("has_skip", [False]) -def test_group_norm_ck(sd_sizes, dtype, silu, has_skip): - silu_suffix = "Silu" if silu else "Pass" - ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) - run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name) - - -@dataclass -class GroupNormNHWCMetric(ke.BandwidthMetric): - batch_size: int - height: int - width: int - num_channels: int - groups: int - - def report(self): - common = ( - f"{self.dtype:<4} batch={self.batch_size:<4} height={self.height:<4} width={self.width:<4}" - f"num_channels={self.num_channels:<6} groups={self.groups:<4} {self.name}" - ) - if self.duration > 0: - return f"{self.duration:>6.2f} us, {self.gbps:>5.2f} GB/s " + common - return "not supported " + common - - -@ke.dispatchable(pattern_arg=8) -def profile_group_norm_func( - batch_size: int, - height: int, - width: int, - num_channels: int, - num_groups: int, - dtype: str, - silu: bool, - has_skip: bool, - func, -): - np.random.seed(0) - input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - gamma = np.random.rand(num_channels).astype(np.float32) - beta = np.random.rand(num_channels).astype(np.float32) - workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32) - epsilon = 0.05 - output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - - skip_x = ( - np.random.rand(batch_size, height, 
width, num_channels).astype(dtype) - if has_skip - else np.empty((0), dtype=dtype) - ) - bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype) - add_output = ( - np.random.rand(batch_size, height, width, num_channels).astype(dtype) - if has_skip - else np.empty((0), dtype=dtype) - ) - use_silu = silu - broadcast_skip = False - channels_per_block = 0 # Compute in params initialization - - input_d = ke.DeviceArray(input_x) - skip_d = ke.DeviceArray(skip_x) - bias_d = ke.DeviceArray(bias_x) - gamma_d = ke.DeviceArray(gamma) - beta_d = ke.DeviceArray(beta) - workspace_d = ke.DeviceArray(workspace) - y_d = ke.DeviceArray(output_y) - y_add_d = ke.DeviceArray(add_output) - f = getattr(ke, func) - - my_op = f( - y_d, - y_add_d, - input_d, - skip_d, - bias_d, - gamma_d, - beta_d, - workspace_d, - epsilon, - batch_size, - num_channels, - height, - width, - num_groups, - use_silu, - broadcast_skip, - channels_per_block, - ) - for impl in my_op.ListOps(): - duration_ms = -1 - if my_op.SelectOp(impl): - duration_ms = my_op.Profile() - total_bytes = (input_x.size * 2 + gamma.size * 2) * dtype_to_bytes(dtype) - - ke.report( - GroupNormNHWCMetric( - impl, dtype, duration_ms, total_bytes, batch_size, height, width, num_channels, num_groups - ) - ) - - -@ke.dispatchable -def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True): - with ke.benchmark(): - for func in dtype_to_funcs(dtype): - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func) - # ck function - silu_suffix = "Silu" if silu else "Pass" - ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name) - - -sd_profile_sizes = [ - (2, 64, 64, 320, 32), - (2, 32, 32, 640, 32), - (2, 16, 16, 1280, 32), - (2, 64, 64, 640, 32), - (2, 16, 16, 2560, 32), - (2, 32, 32, 1280, 32), - (2, 32, 32, 1920, 32), - (2, 8, 8, 1280, 32), - (2, 64, 64, 960, 32), - (2, 32, 32, 960, 32), - (2, 32, 32, 320, 32), - (2, 16, 16, 640, 32), - (2, 16, 16, 1920, 32), - (2, 8, 8, 2560, 32), -] - - -def profile(): - for dtype in dtypes: - for sd_size in sd_profile_sizes: - profile_with_args(*sd_size, dtype) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("batch_size", type=int) - group.add_argument("height", type=int) - group.add_argument("width", type=int) - group.add_argument("num_channels", type=int) - group.add_argument("num_groups", type=int) - group.add_argument("dtype", choices=dtypes) - group.add_argument("--silu", action="store_true") - group.add_argument("--has_skip", action="store_true") - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch( - args.batch_size, - args.height, - args.width, - args.num_channels, - args.num_groups, - args.dtype, - args.silu, - args.has_skip, - ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py b/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py deleted file mode 100644 index 8be8481fd1394..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py +++ /dev/null @@ -1,364 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -"""This file provides wrapper for native _kernel_explorer.so library and benchmark reporter for operator""" - -from __future__ import annotations - -import ctypes -import json -import os -import sys -from abc import abstractmethod -from argparse import Action, ArgumentParser -from collections.abc import Callable -from contextlib import contextmanager -from dataclasses import dataclass -from fnmatch import fnmatch -from functools import wraps - -build_dir = os.environ.get("KERNEL_EXPLORER_BUILD_DIR", None) -if build_dir is None: - raise ValueError("Environment variable KERNEL_EXPLORER_BUILD_DIR is required") - -if not os.path.exists(build_dir): - raise ValueError(f"KERNEL_EXPLORER_BUILD_DIR ({build_dir}) points to nonexistent path") - -# onnxruntime_pybind11_state and kernel_explorer -sys.path.insert(0, build_dir) - -# pylint: disable=wrong-import-position -import onnxruntime_pybind11_state # noqa: E402 - -# We need to call some functions to properly initialize so pointers in the library -available_providers = onnxruntime_pybind11_state.get_available_providers() - - -build_dir = os.path.realpath(build_dir) -search_paths = [build_dir] - -# As Kernel Explorer makes use of utility functions in ONNXRuntime, we dlopen all relevant libraries to bring required -# symbols into global namespace, so that we don't need to worry about linking. -library_files_to_load = [ - "onnxruntime_pybind11_state.so", - "libonnxruntime_providers_shared.so", -] -_is_cuda_available = False -_is_rocm_available = False -if "CUDAExecutionProvider" in available_providers: - library_files_to_load.append("libonnxruntime_providers_cuda.so") - _is_cuda_available = True -if "ROCMExecutionProvider" in available_providers: - library_files_to_load.append("libonnxruntime_providers_rocm.so") - _is_rocm_available = True - -library_to_load = [] - -for lib in library_files_to_load: - for prefix in search_paths: - path = os.path.join(prefix, lib) - if os.path.exists(path): - library_to_load.append(path) - continue - - raise OSError(f"cannot found {lib}") - - -# use RTLD_GLOBAL to bring all symbols to global name space -_libraries = [ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) for lib_path in library_to_load] -del library_files_to_load, library_to_load - -# pylint: disable=wrong-import-position, disable=unused-import -import _kernel_explorer # noqa: E402 - -# pylint: disable=wrong-import-position, disable=unused-import, disable=wildcard-import -from _kernel_explorer import * # noqa: F403, E402 - - -@dataclass -class _KeContext: - sort: bool = False - - pattern = "*" - - # mapping the module to dispatch to - dispatchable: dict | None = None - instance_dispatchable: dict | None = None # can be filtered with pattern - - dispatch_depth = 0 - - save_tuning_results: str | None = None - return_tuning_results: bool = False - - -_ke_context = _KeContext() -_ke_context.dispatchable = {} -_ke_context.instance_dispatchable = {} - - -# Benchmark Reporter -@dataclass -class MetricBase: - name: str - dtype: str - milliseconds_duration: float - - def __lt__(self, other): - if "Tunable" in self.name or other.duration < 0: - return True - if "Tunable" in other.name or self.duration < 0: - return False - - return self.duration < other.duration - - @property - def duration(self): - return self.milliseconds_duration * 1000 - - @abstractmethod - def report(self) -> str: - raise NotImplementedError() - - -@dataclass -class ComputeMetric(MetricBase): - FLOPs: int - - @property 
- def tflops(self): - return self.FLOPs * 1e6 / self.duration / 1e12 - - -@dataclass -class BandwidthMetric(MetricBase): - bytes: int - - @property - def gbps(self): - return self.bytes * 1e6 / self.duration / 1e9 - - -@dataclass -class ComputeAndBandwidthMetric(ComputeMetric, BandwidthMetric): - pass - - -class InstanceBenchmarkReporter: - def __init__(self): - self.best = float("inf") - self.reporters = [] - - def make_report(self): - self.reporters.sort() - for item in self.reporters: - if not _ke_context.sort and item.milliseconds_duration > 0 and item.milliseconds_duration < self.best: - self.best = item.milliseconds_duration - print(item.report(), "*") - else: - print(item.report()) - self.reporters.clear() - - def receive(self, status): - self.reporters.append(status) - if not _ke_context.sort: - self.make_report() - - def _reset_best(self): - self.best = float("inf") - - -_reporter = InstanceBenchmarkReporter() - - -@contextmanager -def benchmark(): - _reporter._reset_best() - try: - yield - finally: - _reporter.make_report() - - -def report(status): - _reporter.receive(status) - - -def set_ort_severity(v): - v = int(v) - onnxruntime_pybind11_state.set_default_logger_severity(v) - return v - - -def set_ort_verbosity(v): - v = int(v) - onnxruntime_pybind11_state.set_default_logger_verbosity(v) - return v - - -def register_common_arguments(parser: ArgumentParser): - class SortAction(Action): - def __init__(self, option_strings, dest, default=False, help=None): - super().__init__(option_strings=option_strings, dest=dest, nargs=0, default=default, help=help) - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, True) - _ke_context.sort = True - - def set_dispatch(name): - if name in _ke_context.dispatchable: - dispatch = _ke_context.dispatchable[name] - _ke_context.dispatch = dispatch - return dispatch - - if name in _ke_context.instance_dispatchable: - msg = f"'{name}' needs an instance to dispatch, thus it is not dispatchable from commandline." - print(msg) - raise ValueError(msg) - - from difflib import SequenceMatcher as Matcher - - valid_names = list(_ke_context.dispatchable.keys()) - scored_names = sorted([(Matcher(None, name, a).ratio(), a) for a in valid_names], reverse=True) - top10 = "\n ".join([a for _, a in scored_names[:10]]) - msg = f"'{name}' is not registered for dispatch. 
Top 10 matches are:\n {top10}" - print(msg) - raise ValueError(msg) - - def set_pattern(pattern): - pattern = str(pattern) - _ke_context.pattern = pattern - - def set_save_tuning_results(path): - _ke_context.save_tuning_results = path - return path - - group = parser.add_argument_group("kernel explorer args", "Common arguments for kernel explorer") - group.add_argument( - "--sort", - action=SortAction, - help="control the sort of ke benchmark results based on timing", - ) - group.add_argument( - "--ort_default_logger_severity", - default=2, - choices=[0, 1, 2, 3, 4], - type=set_ort_severity, - help="0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal", - ) - group.add_argument("--ort_default_logger_verbosity", default=0, type=set_ort_verbosity) - group.add_argument( - "--dispatch", - default="profile_with_args", - help="dispatch a registered dispatchable.", - type=set_dispatch, - ) - group.add_argument( - "--pattern", - default="*", - help="filter the register instanced dispatchables, only matched pattern will be run.", - type=set_pattern, - ) - group.add_argument( - "--save_tuning_results", - default=None, - type=set_save_tuning_results, - help="patch the dispatch function to save tuning results to the specified path.", - ) - - return parser - - -def get_argument_parser(): - parser = ArgumentParser() - return register_common_arguments(parser) - - -def has_args(): - if "--help" in sys.argv or "-h" in sys.argv or "--func" in sys.argv: - return True - - # parse the KE args group - parser = get_argument_parser() - _, remainder = parser.parse_known_args(sys.argv) - return len(remainder) > 1 # the file path is always the remainder - - -def is_cuda_available(): - return _is_cuda_available - - -def is_rocm_available(): - return _is_rocm_available - - -def dispatchable(f: Callable | None = None, *, pattern_arg: int | None = None): - def wrap_dispatch(f, *args, **kwargs): - if _ke_context.dispatch_depth == 0: - if _ke_context.save_tuning_results is not None: - _kernel_explorer.enable_collect_tuning_results() - _ke_context.dispatch_depth += 1 - ret = f(*args, **kwargs) - _ke_context.dispatch_depth -= 1 - if _ke_context.dispatch_depth == 0: - if _ke_context.save_tuning_results is not None: - try: - trs = _kernel_explorer.get_collected_tuning_results() - with open(_ke_context.save_tuning_results, "x") as f: - json.dump(trs, f) - finally: - pass - - if _ke_context.return_tuning_results: - if ret is not None: - print( - f"WARNING: kernel explorer wants to override the return value of {f.__name__},", - "but original return value is not None!", - ) - return ret - try: - trs = _kernel_explorer.get_collected_tuning_results() - return trs - finally: - pass - - return ret - - if f is None: # Used with ke.dispatchable(...) 
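        # The decorator is used in two forms across the kernel explorer scripts:
        #   @ke.dispatchable                  -> f is the function; it is registered in
        #                                        _ke_context.dispatchable and wrapped directly.
        #   @ke.dispatchable(pattern_arg=0)   -> f is None here; the returned decorator registers
        #                                        the function in _ke_context.instance_dispatchable
        #                                        and, at call time, skips it unless the positional
        #                                        argument at index pattern_arg matches --pattern.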
- assert pattern_arg is not None - - def decorator(f): - _ke_context.instance_dispatchable[f.__name__] = f - - @wraps(f) - def wrapper(*args, **kwargs): - func_name = args[pattern_arg] if isinstance(args[pattern_arg], str) else args[pattern_arg].__name__ - if not fnmatch(func_name, _ke_context.pattern): - print( - f"Trying to run {func_name},", - f"does not match allowed function name pattern '{_ke_context.pattern}', skip...", - ) - return - return wrap_dispatch(f, *args, **kwargs) - - return wrapper - - return decorator - - else: # Used with @ke.dispatchable - _ke_context.dispatchable[f.__name__] = f - - @wraps(f) - def wrapper(*args, **kwargs): - return wrap_dispatch(f, *args, **kwargs) - - return wrapper - - -def set_dispatchable_pattern(p: str = "*"): - _ke_context.pattern = p - - -def set_return_tuning_results(b: bool = True): - _ke_context.return_tuning_results = b diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py deleted file mode 100644 index c6f20dcb71e6c..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py +++ /dev/null @@ -1,138 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from dataclasses import dataclass - -import kernel_explorer as ke -import numpy as np -from utils import dtype_to_bytes - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: "MatrixFloatInt4_half" in x, dir(ke))), - "float32": list(filter(lambda x: "MatrixFloatInt4_float" in x, dir(ke))), - } - return type_map[dtype] - - -def dtype_to_funcs_cublas(dtype): - type_map = { - "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))), - "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))), - } - return type_map[dtype] - - -dtypes = ["float16", "float32"] - - -@dataclass -class MatrixMulMetric(ke.BandwidthMetric): - m: int - n: int - k: int - - def report(self): - return ( - f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" - ) - - -@dataclass -class MatrixFpInt4Metric(MatrixMulMetric): - is_symmetric: bool - - def report(self): - return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} is_symmetric={self.is_symmetric} {self.name}" - - -@ke.dispatchable(pattern_arg=4) -def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric): - np.random.seed(0) - output = np.random.rand(m, n).astype(dtype) - a = np.random.rand(m, k).astype(dtype) - b = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8") - scales = np.random.rand(n * ((k + 31) // 32)).astype(dtype) - zeropoints = np.random.rand(n * (((k + 31) // 32 + 1) // 2)).astype(dtype) - - output_d = ke.DeviceArray(output) - a_d = ke.DeviceArray(a) - b_d = ke.DeviceArray(b) - scales_d = ke.DeviceArray(scales) - zeropoints_d = ke.DeviceArray(zeropoints) - f = getattr(ke, func) - - my_op = ( - f(output_d, a_d, b_d, scales_d, m, n, k) - if is_symmetric - else f(output_d, a_d, b_d, scales_d, zeropoints_d, m, n, k) - ) - duration_ms = my_op.Profile() - total_bytes = (m * k + m * n) * (dtype_to_bytes(dtype)) + n * k / 2 - - ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k, is_symmetric)) - - -@ke.dispatchable(pattern_arg=4) 
-def profile_gemm_func(m, n, k, dtype, func): - np.random.seed(0) - output = np.random.rand(m, n).astype(dtype) - a = np.random.rand(m, k).astype(dtype) - b = np.random.rand(k, n).astype(dtype) - - output_d = ke.DeviceArray(output) - a_d = ke.DeviceArray(a) - b_d = ke.DeviceArray(b) - f = getattr(ke, func) - my_op = f(output_d, a_d, b_d, m, n, k) - duration_ms = my_op.Profile() - total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) - - ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) - - -@ke.dispatchable -def profile_with_args(m, n, k, dtype): - with ke.benchmark(): - for func in dtype_to_funcs(dtype): - profile_matmul_fp_int4_func(m, n, k, dtype, func, True) - - for func in dtype_to_funcs(dtype): - profile_matmul_fp_int4_func(m, n, k, dtype, func, False) - - for func in dtype_to_funcs_cublas(dtype): - profile_gemm_func(m, n, k, dtype, func) - - -def profile(): - dims_m = [1] - for dt in dtypes: - for m in dims_m: - for n, k in ( - (4096, 4096), - (4096, 12288), - (12288, 4096), - (4096, 11008), - (11008, 4096), - (2 * 11008, 4096), - ): - profile_with_args(m, n, k, dt) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("m", type=int) - group.add_argument("n", type=int) - group.add_argument("k", type=int) - group.add_argument("dtype", choices=dtypes) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.m, args.n, args.k, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py deleted file mode 100644 index 4a9489050fd61..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py +++ /dev/null @@ -1,136 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -import sys -from dataclasses import dataclass - -import kernel_explorer as ke -import numpy as np -from utils import dtype_to_bytes - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))), - "float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))), - } - return type_map[dtype] - - -def dtype_to_funcs_cublas(dtype): - type_map = { - "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))), - "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))), - } - return type_map[dtype] - - -quant_enums = {"FP4": 0, "NF4": 1} - - -dtypes = ["float16", "float32"] -quant_types = ["FP4", "NF4"] - - -@dataclass -class MatrixMulMetric(ke.BandwidthMetric): - m: int - n: int - k: int - - def report(self): - return ( - f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" - ) - - -@dataclass -class MatrixFpBnb4Metric(MatrixMulMetric): - quant_type: str - - def report(self): - return ( - f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" - f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" - ) - - -def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func): - np.random.seed(0) - block_size = 64 - numel = n * k - output = np.random.rand(m, n).astype(dtype) - a = np.random.rand(m, k).astype(dtype) - b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") - absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) - quant_map_buffer = np.zeros(16).astype(dtype) - - output_d = ke.DeviceArray(output) - a_d = ke.DeviceArray(a) - b_d = ke.DeviceArray(b) - absmax_d = ke.DeviceArray(absmax) - quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) - f = getattr(ke, func) - - my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k) - duration_ms = my_op.Profile() - total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) - - ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt)) - - -def profile_gemm_func(m, n, k, dtype, func): - np.random.seed(0) - output = np.random.rand(m, n).astype(dtype) - a = np.random.rand(m, k).astype(dtype) - b = np.random.rand(k, n).astype(dtype) - - output_d = ke.DeviceArray(output) - a_d = ke.DeviceArray(a) - b_d = ke.DeviceArray(b) - f = getattr(ke, func) - my_op = f(output_d, a_d, b_d, m, n, k) - duration_ms = my_op.Profile() - total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) - - ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) - - -def profile_with_args(qt, m, n, k, dtype, sort): - with ke.benchmark(sort): - for func in dtype_to_funcs(dtype): - profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func) - - for func in dtype_to_funcs_cublas(dtype): - profile_gemm_func(m, n, k, dtype, func) - - -def profile(): - dims_m = [1] - for qt in quant_types: - for dt in dtypes: - for m in dims_m: - for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): - profile_with_args(qt, m, n, k, dt, False) - print() - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") - group.add_argument("m", type=int) - group.add_argument("n", type=int) - group.add_argument("k", type=int) - group.add_argument("quant_type", choices=quant_types) - group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", 
action="store_true") - - if len(sys.argv) == 1: - profile() - else: - args = parser.parse_args() - profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/elementwise.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/elementwise.cu deleted file mode 100644 index 2151a8cc272c6..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/elementwise.cu +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include "contrib_ops/rocm/bert/elementwise.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -namespace py = pybind11; -using namespace onnxruntime::contrib::rocm; - -namespace onnxruntime { - -template -class Elementwise : public IKernelExplorer { - public: - Elementwise(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length) - : params_(TuningContext(), Stream(), static_cast(input.ptr()), static_cast(bias.ptr()), - static_cast(output.ptr()), input_length, bias_length) {} - - bool IsSupported() { - Status status = op_.IsSupported(¶ms_); - return status.IsOK(); - } - - void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - private: - using ParamsT = internal::ElementwiseParams; - ParamsT params_{}; - internal::ElementwiseOp op_{}; -}; - -template -class ElementwiseStaticSelection : public IKernelExplorer { - public: - ElementwiseStaticSelection(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length) - : params_(TuningContext(), Stream(), static_cast(input.ptr()), static_cast(bias.ptr()), - static_cast(output.ptr()), input_length, bias_length) {} - - bool IsSupported() { - return true; - } - - void Run() override { - ORT_THROW_IF_ERROR((internal::ElementwiseStaticSelection(¶ms_))); - } - - private: - using ParamsT = internal::ElementwiseParams; - ParamsT params_{}; -}; - -template -class ElementwiseTunable : public IKernelExplorer { - public: - ElementwiseTunable(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length) - : params_(TuningContext(), Stream(), static_cast(input.ptr()), static_cast(bias.ptr()), - static_cast(output.ptr()), input_length, bias_length) { - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - void Run() override { - WithMaxTuningDurationMs max_duration(TuningContext(), 250); - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - bool IsSupported() { - return true; - } - - private: - using ParamsT = internal::ElementwiseParams; - ParamsT params_{}; - internal::ElementwiseTunableOp op_{}; -}; - -#define REGISTER_OP(registered_name, tpl, functor_name, dtype, threads_per_block, vec_size) \ - py::class_>( \ - m, #registered_name "_" #dtype "_" #threads_per_block "_" #vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &tpl::SetRepeats) \ - .def("Profile", &tpl::Profile) \ - .def("Run", &tpl::Run) \ - .def("IsSupported", &tpl::IsSupported); - -#define REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, threads_per_block) \ - REGISTER_OP(functor_name, tpl, functor_name, dtype, threads_per_block, 1) \ - REGISTER_OP(functor_name, tpl, functor_name, dtype, threads_per_block, 2) \ - REGISTER_OP(functor_name, tpl, functor_name, dtype, threads_per_block, 4) \ - REGISTER_OP(functor_name, tpl, functor_name, dtype, threads_per_block, 8) \ - 
REGISTER_OP(functor_name, tpl, functor_name, dtype, threads_per_block, 16) - -#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(registered_name, tpl, functor_name, dtype) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 64) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 128) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 192) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 256) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 320) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 384) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 448) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(registered_name, tpl, functor_name, dtype, 512) - -#define REGISTER_OP_TYPED(registered_name, tpl, functor_name, dtype) \ - py::class_>(m, registered_name "_" #dtype) \ - .def(py::init()) \ - .def("SetRepeats", &tpl::SetRepeats) \ - .def("Profile", &tpl::Profile) \ - .def("Run", &tpl::Run) \ - .def("IsSupported", &tpl::IsSupported); - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK("FastGeLU", Elementwise, FastGeLU, half); - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK("FastGeLU", Elementwise, FastGeLU, float); - - REGISTER_OP_TYPED("FastGeLUTunable", ElementwiseTunable, FastGeLU, half); - REGISTER_OP_TYPED("FastGeLUTunable", ElementwiseTunable, FastGeLU, float); - - REGISTER_OP_TYPED("FastGeLUStaticSelection", ElementwiseStaticSelection, FastGeLU, half); - REGISTER_OP_TYPED("FastGeLUStaticSelection", ElementwiseStaticSelection, FastGeLU, float); -} - -KE_REGISTER(m) { -REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK("GeLU", Elementwise, GeLU, half); -REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK("GeLU", Elementwise, GeLU, float); - -REGISTER_OP_TYPED("GeLUTunable", ElementwiseTunable, GeLU, half); -REGISTER_OP_TYPED("GeLUTunable", ElementwiseTunable, GeLU, float); - -REGISTER_OP_TYPED("GeLUStaticSelection", ElementwiseStaticSelection, GeLU, half); -REGISTER_OP_TYPED("GeLUStaticSelection", ElementwiseStaticSelection, GeLU, float); -} - -KE_REGISTER(m) { -REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK("ReLU", Elementwise, ReLU, half); -REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK("ReLU", Elementwise, ReLU, float); - -REGISTER_OP_TYPED("ReLUTunable", ElementwiseTunable, ReLU, half); -REGISTER_OP_TYPED("ReLUTunable", ElementwiseTunable, ReLU, float); - -REGISTER_OP_TYPED("ReLUStaticSelection", ElementwiseStaticSelection, ReLU, half); -REGISTER_OP_TYPED("ReLUStaticSelection", ElementwiseStaticSelection, ReLU, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.cc deleted file mode 100644 index 540964a14912a..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include - -#include "core/providers/rocm/tunable/gemm_common.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using BlasOp = onnxruntime::rocm::tunable::blas::BlasOp; - -namespace py = pybind11; - -namespace onnxruntime { - -KE_REGISTER(mod) { - auto blas_op = mod.def_submodule("blas_op"); - - py::enum_(blas_op, "BlasOp") - .value("N", BlasOp::N, "Passthrough") - .value("T", BlasOp::T, "Transpose") - .export_values(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu deleted file mode 100644 index 6c6bc147bd2a0..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ck.cu +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include -#include -#include -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/gemm_ck.cuh" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using namespace onnxruntime::rocm::tunable::blas; -using namespace onnxruntime::rocm::tunable::blas::internal; - -namespace py = pybind11; - -namespace onnxruntime { - -#ifdef USE_COMPOSABLE_KERNEL -template -class CKGemm : public IKernelExplorer { - public: - CKGemm(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - double beta, - DeviceArray& c, int64_t ldc) - : params_{} { - ORT_ENFORCE(opa == OpA && opb == OpB); - - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for ck - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - for (auto&& [type_string, op] : GetCKGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - for (auto&& [type_string, op] : GetCKStreamKGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - for (auto&& [type_string, op] : GetCKSplitKGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - ORT_ENFORCE(!ops_.empty()); - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = GemmParams; - using OpT = Op; - ParamsT params_; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -template -class CKStridedBatchedGemm : public IKernelExplorer { - public: - CKStridedBatchedGemm( - BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, int64_t stride_a, - 
DeviceArray& b, int64_t ldb, int64_t stride_b, - double beta, - DeviceArray& c, int64_t ldc, int64_t stride_c, - int64_t batch) - : params_{} { - ORT_ENFORCE(opa == OpA && opb == OpB); - - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for ck - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.stride_a = stride_a; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.stride_b = stride_b; - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - params_.stride_c = stride_c; - params_.batch = batch; - - for (auto&& [type_string, op] : GetCKStridedBatchedGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - ORT_ENFORCE(!ops_.empty()); - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = StridedBatchedGemmParams; - using OpT = Op; - ParamsT params_; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -#define REGISTER_OP_COMMON(type, dtype, opa, opb, layout_string) \ - py::class_>(m, #type "_" #dtype "_" layout_string) \ - .def("SetRepeats", &type::SetRepeats) \ - .def("Profile", &type::Profile) \ - .def("Run", &type::Run) \ - .def("ListOps", &type::ListOps) \ - .def("SelectOp", &type::SelectOp) - -#define REGISTER_CKGEMM(dtype, opa, opb, layout_string) \ - REGISTER_OP_COMMON(CKGemm, dtype, opa, opb, layout_string) \ - .def(py::init()); - -#define REGISTER_CKGEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_CKGEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_CKGEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_CKGEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_CKGEMM(dtype, BlasOp::T, BlasOp::T, "TT"); - -#define REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, opa, opb, layout_string) \ - REGISTER_OP_COMMON(CKStridedBatchedGemm, dtype, opa, opb, layout_string) \ - .def(py::init()); - -#define REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_CKSTRIDEDBATCHEDGEMM(dtype, BlasOp::T, BlasOp::T, "TT"); - -KE_REGISTER(m) { - REGISTER_CKGEMM_FOR_ALL_TRANSAB(float); - REGISTER_CKGEMM_FOR_ALL_TRANSAB(half); - - REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(float); - REGISTER_CKSTRIDEDBATCHEDGEMM_FOR_ALL_TRANSAB(half); -} -#endif // USE_COMPOSABLE_KERNEL - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu deleted file mode 100644 index ec7083186b977..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
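The CK GEMM wrappers removed above all share the same ListOps/SelectOp shape: candidate implementations arrive as (type_string, op) pairs, SelectOp records the matching index and uses a trial invocation as the support check, and Run forwards to whichever candidate was selected. A minimal sketch of that dispatch, with the ORT types replaced by hypothetical stand-ins (`Params`, `OpExplorer`, ops returning bool instead of Status):

```cpp
// Sketch of the name-based op selection used by the deleted CK wrappers.
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

struct Params { int m, n, k; };
using Op = std::function<bool(const Params*)>;  // true == supported and ran OK

class OpExplorer {
 public:
  OpExplorer() {
    // In the real wrappers these come from Get*TypeStringAndOps() generators.
    type_strings_ = {"naive", "vectorized"};
    ops_ = {
        [](const Params*) { return true; },
        [](const Params* p) { return p->k % 4 == 0; },  // supports only k % 4 == 0
    };
  }

  std::vector<std::string> ListOps() const { return type_strings_; }

  bool SelectOp(const std::string& name) {
    for (size_t i = 0; i < ops_.size(); ++i) {
      if (type_strings_[i] == name) {
        selected_op_ = i;
        return ops_[i](&params_);  // a dry run doubles as the support check
      }
    }
    throw std::runtime_error("Cannot find implementation " + name);
  }

  void Run() { ops_[selected_op_](&params_); }

 private:
  Params params_{128, 128, 126};
  std::vector<Op> ops_;
  std::vector<std::string> type_strings_;
  size_t selected_op_ = 0;
};

int main() {
  OpExplorer e;
  for (const auto& name : e.ListOps())
    std::cout << name << ": " << (e.SelectOp(name) ? "supported" : "unsupported") << "\n";
}
```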
- -#include -#include - -#include -#include -#include -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using namespace onnxruntime::contrib::rocm::blas; -using namespace onnxruntime::contrib::rocm::blas::internal; - -namespace py = pybind11; - -namespace onnxruntime { - -#ifdef USE_COMPOSABLE_KERNEL -template -class CKGemmFastGelu : public IKernelExplorer { - public: - CKGemmFastGelu(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - DeviceArray& c, int64_t ldc) - : params_{} { - ORT_ENFORCE(opa == OpA && opb == OpB); - - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for ck - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - for (auto&& [type_string, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - for (auto&& [type_string, op] : GetCKGemmFastGeluTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = GemmFastGeluParams; - using OpT = Op; - ParamsT params_; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -#define REGISTER_OP(type, opa, opb, layout_string) \ - py::class_>(m, "CKGemmFastGelu_" #type "_" layout_string) \ - .def(py::init()) \ - .def("SetRepeats", &CKGemmFastGelu::SetRepeats) \ - .def("Profile", &CKGemmFastGelu::Profile) \ - .def("Run", &CKGemmFastGelu::Run) \ - .def("ListOps", &CKGemmFastGelu::ListOps) \ - .def("SelectOp", &CKGemmFastGelu::SelectOp); - -#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ - REGISTER_OP(type, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_OP(type, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_OP(type, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_OP(type, BlasOp::T, BlasOp::T, "TT"); - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_TRANSAB(float); - REGISTER_OP_FOR_ALL_TRANSAB(half); -} -#endif // USE_COMPOSABLE_KERNEL - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu deleted file mode 100644 index 4d8ecfc34219e..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#include -#include - -#ifdef USE_HIPBLASLT -#include "core/providers/rocm/tunable/gemm_hipblaslt.h" -#endif - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/providers/rocm/rocm_common.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -namespace py = pybind11; - -namespace onnxruntime { - -#ifdef USE_HIPBLASLT - -using namespace rocm::tunable::blas::internal; - -template -class GemmFastGeluHipBlasLt : public IKernelExplorer { - public: - GemmFastGeluHipBlasLt(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - DeviceArray& c, int64_t ldc) : params_{} { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for hipBLASLt - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - for (auto&& [type_string, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - ORT_ENFORCE(!ops_.empty()); - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = contrib::rocm::blas::GemmFastGeluParams; - using OpT = Op; - ParamsT params_; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -#define REGISTER_OP(type, opa, opb, layout_string) \ - py::class_>(m, "GemmFastGeluHipBlasLt_" #type "_" layout_string) \ - .def(py::init()) \ - .def("SetRepeats", &GemmFastGeluHipBlasLt::SetRepeats) \ - .def("Profile", &GemmFastGeluHipBlasLt::Profile) \ - .def("Run", &GemmFastGeluHipBlasLt::Run) \ - .def("ListOps", &GemmFastGeluHipBlasLt::ListOps) \ - .def("SelectOp", &GemmFastGeluHipBlasLt::SelectOp); - -#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ - REGISTER_OP(type, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_OP(type, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_OP(type, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_OP(type, BlasOp::T, BlasOp::T, "TT"); - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_TRANSAB(float); - REGISTER_OP_FOR_ALL_TRANSAB(half); -} -#endif // USE_HIPBLASLT - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu deleted file mode 100644 index 3f375c67acf85..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using namespace onnxruntime::contrib::rocm::blas; -using namespace onnxruntime::contrib::rocm::blas::internal; - -namespace py = pybind11; - -namespace onnxruntime { -template -class GemmFastGeluTunable : public IKernelExplorer { - public: - GemmFastGeluTunable(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - DeviceArray& c, int64_t ldc) : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - ~GemmFastGeluTunable() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - WithMaxTuningDurationMs max_duration(TuningContext(), 250); - ORT_THROW_IF_ERROR((op_(¶ms_))); - } - - std::vector ListOps() const { - return {"GemmFastGeluTunable"}; - } - - bool SelectOp(const std::string& name) { - return name == "GemmFastGeluTunable"; - } - - private: - using ParamsT = GemmFastGeluParams; - ParamsT params_{}; - rocblas_handle rocblas_handle_; - GemmFastGeluTunableOp op_{}; -}; - -#define REGISTER_OP(type, opa, opb, layout_string) \ - py::class_>(m, "GemmFastGeluTunable_" #type "_" layout_string) \ - .def(py::init()) \ - .def("SetRepeats", &GemmFastGeluTunable::SetRepeats) \ - .def("Profile", &GemmFastGeluTunable::Profile) \ - .def("Run", &GemmFastGeluTunable::Run) \ - .def("ListOps", &GemmFastGeluTunable::ListOps) \ - .def("SelectOp", &GemmFastGeluTunable::SelectOp); - -#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ - REGISTER_OP(type, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_OP(type, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_OP(type, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_OP(type, BlasOp::T, BlasOp::T, "TT"); - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_TRANSAB(float); - REGISTER_OP_FOR_ALL_TRANSAB(half); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu deleted file mode 100644 index b39a3eab04a16..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using namespace onnxruntime::contrib::rocm::blas; -using namespace onnxruntime::contrib::rocm::blas::internal; - -namespace py = pybind11; - -namespace onnxruntime { - -template -class GemmFastGeluUnfused : public IKernelExplorer { - public: - GemmFastGeluUnfused(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - DeviceArray& c, int64_t ldc) : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - } - - ~GemmFastGeluUnfused() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::blas::internal::GemmFastGeluUnfused(¶ms_))); - } - - std::vector ListOps() const { - return {"GemmFastGeluUnfused"}; - } - - bool SelectOp(const std::string& name) { - Status status = contrib::rocm::blas::internal::GemmFastGeluUnfused(¶ms_); - return status.IsOK() && name == "GemmFastGeluUnfused"; - } - - private: - using ParamsT = GemmFastGeluParams; - ParamsT params_{}; - rocblas_handle rocblas_handle_; -}; - -#define REGISTER_OP(type) \ - py::class_>(m, "GemmFastGeluUnfused_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &GemmFastGeluUnfused::SetRepeats) \ - .def("Run", &GemmFastGeluUnfused::Run) \ - .def("Profile", &GemmFastGeluUnfused::Profile) \ - .def("ListOps", &GemmFastGeluUnfused::ListOps) \ - .def("SelectOp", &GemmFastGeluUnfused::SelectOp); - -KE_REGISTER(m) { - REGISTER_OP(float) - REGISTER_OP(half) -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu deleted file mode 100644 index 2d78f390af84a..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
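The rocBLAS-backed wrappers removed above (GemmFastGeluTunable, GemmFastGeluUnfused, and the rocBLAS/tunable GEMM explorers later in this diff) each create a rocblas_handle in their constructor and destroy it in their destructor. A small RAII sketch of that handle lifetime, assuming only the public rocBLAS create/destroy API — `RocblasHandle` is an illustrative name, not ORT code:

```cpp
// RAII wrapper mirroring the create/destroy pairing in the deleted wrappers.
// Header path may be <rocblas.h> on older ROCm installs.
#include <rocblas/rocblas.h>
#include <stdexcept>

class RocblasHandle {
 public:
  RocblasHandle() {
    if (rocblas_create_handle(&handle_) != rocblas_status_success) {
      throw std::runtime_error("rocblas_create_handle failed");
    }
  }
  ~RocblasHandle() {
    // Destructors must not throw; ignore the status on teardown.
    rocblas_destroy_handle(handle_);
  }
  RocblasHandle(const RocblasHandle&) = delete;
  RocblasHandle& operator=(const RocblasHandle&) = delete;

  rocblas_handle get() const { return handle_; }

 private:
  rocblas_handle handle_ = nullptr;
};
```

A params struct can then borrow `get()` for the lifetime of the explorer object, which is the lifetime the deleted wrappers relied on.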
- -#include - -#include -#include -#include -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using namespace onnxruntime::rocm::tunable::blas; - -namespace py = pybind11; - -namespace onnxruntime { - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) -template -class GemmFloat8CK : public IKernelExplorer { - public: - GemmFloat8CK(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - float alpha, - DeviceArray& a, int64_t lda, DeviceArray& scale_a, - DeviceArray& b, int64_t ldb, DeviceArray& scale_b, - float beta, - DeviceArray& c, int64_t ldc, DeviceArray& scale_c) { - ORT_ENFORCE(opa == OpA && opb == OpB); - - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for ck - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - - params_.a = static_cast(a.ptr()); - params_.lda = lda; - if constexpr (std::is_same_v || std::is_same_v) { - params_.scale_a = alpha; - params_.scale_a_dev = static_cast(scale_a.ptr()); - } - - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - if constexpr (std::is_same_v || std::is_same_v) { - params_.scale_b = alpha; - params_.scale_b_dev = static_cast(scale_b.ptr()); - } - - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - if constexpr (std::is_same_v || std::is_same_v) { - ORT_ENFORCE(false, "Not implemented"); - params_.scale_c = beta; - params_.scale_c_dev = static_cast(scale_c.ptr()); - } - - for (auto&& [type_string, op] : GetCKF8SplitKGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - ORT_ENFORCE(!ops_.empty()); - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = GemmFloat8Params; - using OpT = Op; - ParamsT params_{}; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -template -class GemmFloat8Tunable : public IKernelExplorer { - public: - GemmFloat8Tunable(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - float alpha, - DeviceArray& a, int64_t lda, DeviceArray& scale_a, - DeviceArray& b, int64_t ldb, DeviceArray& scale_b, - float beta, - DeviceArray& c, int64_t ldc, DeviceArray& scale_c) { - ORT_ENFORCE(opa == OpA && opb == OpB); - - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for ck - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - - params_.a = static_cast(a.ptr()); - params_.lda = lda; - if constexpr (std::is_same_v || std::is_same_v) { - params_.scale_a = alpha; - params_.scale_a_dev = static_cast(scale_a.ptr()); - } - - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - if constexpr (std::is_same_v || std::is_same_v) { - params_.scale_b = alpha; - params_.scale_b_dev = static_cast(scale_b.ptr()); - } - - params_.c = 
static_cast(c.ptr()); - params_.ldc = ldc; - if constexpr (std::is_same_v || std::is_same_v) { - ORT_ENFORCE(false, "Not implemented"); - params_.scale_c = beta; - params_.scale_c_dev = static_cast(scale_c.ptr()); - } - - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - std::vector ListOps() const { - return {"Tunable"}; - } - - bool SelectOp(const std::string& name) { - return name == "Tunable"; - } - - private: - using ParamsT = GemmFloat8Params; - using OpT = GemmFloat8TunableOp; - ParamsT params_{}; - OpT op_; -}; - -#define REGISTER_GEMM_FLOAT8(registered_name, tpl, dta, dtb, dtc, opa, opb) \ - py::class_>(m, registered_name) \ - .def("SetRepeats", &tpl::SetRepeats) \ - .def("Profile", &tpl::Profile) \ - .def("Run", &tpl::Run) \ - .def("ListOps", &tpl::ListOps) \ - .def("SelectOp", &tpl::SelectOp) \ - .def(py::init()); - -KE_REGISTER(m) { - using BlasOp = rocm::tunable::blas::BlasOp; - REGISTER_GEMM_FLOAT8("GemmFloat8CK_fp8e4m3fn_half_half_NN", GemmFloat8CK, Float8E4M3FN, half, half, BlasOp::N, BlasOp::N); - REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fn_half_NN", GemmFloat8CK, half, Float8E4M3FN, half, BlasOp::N, BlasOp::N); - REGISTER_GEMM_FLOAT8("GemmFloat8CK_fp8e4m3fnuz_half_half_NN", GemmFloat8CK, Float8E4M3FNUZ, half, half, BlasOp::N, BlasOp::N); - REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fnuz_half_NN", GemmFloat8CK, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::N); - - REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fn_half_NT", GemmFloat8CK, half, Float8E4M3FN, half, BlasOp::N, BlasOp::T); - REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fnuz_half_NT", GemmFloat8CK, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::T); -} - -KE_REGISTER(m) { - using BlasOp = rocm::tunable::blas::BlasOp; - REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_fp8e4m3fn_half_half_NN", GemmFloat8Tunable, Float8E4M3FN, half, half, BlasOp::N, BlasOp::N); - REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fn_half_NN", GemmFloat8Tunable, half, Float8E4M3FN, half, BlasOp::N, BlasOp::N); - REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_fp8e4m3fnuz_half_half_NN", GemmFloat8Tunable, Float8E4M3FNUZ, half, half, BlasOp::N, BlasOp::N); - REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fnuz_half_NN", GemmFloat8Tunable, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::N); - - REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fn_half_NT", GemmFloat8Tunable, half, Float8E4M3FN, half, BlasOp::N, BlasOp::T); - REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fnuz_half_NT", GemmFloat8Tunable, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::T); -} -#endif - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu deleted file mode 100644 index c0658dff193ae..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#include -#include - -#ifdef USE_HIPBLASLT -#include "core/providers/rocm/tunable/gemm_hipblaslt.h" -#endif - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using namespace onnxruntime::rocm::tunable::blas; - -namespace py = pybind11; - -namespace onnxruntime { - -#ifdef USE_HIPBLASLT - -using namespace rocm::tunable::blas::internal; - -template -class GemmHipBlasLt : public IKernelExplorer { - public: - GemmHipBlasLt(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - double beta, - DeviceArray& c, int64_t ldc) - : params_{} { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for hipBLASLt - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - for (auto&& [type_string, op] : GetHipBlasLtGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - ORT_ENFORCE(!ops_.empty()); - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = GemmParams; - using OpT = Op; - ParamsT params_; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -template -class StridedBatchedGemmHipBlasLt : public IKernelExplorer { - public: - StridedBatchedGemmHipBlasLt( - BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, int64_t stride_a, - DeviceArray& b, int64_t ldb, int64_t stride_b, - double beta, - DeviceArray& c, int64_t ldc, int64_t stride_c, - int64_t batch) - : params_{} { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - // rocblas handle is not used for hipBLASLt - params_.handle = nullptr; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.stride_a = stride_a; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.stride_b = stride_b; - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - params_.stride_c = stride_c; - params_.batch = batch; - - for (auto&& [type_string, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - ORT_ENFORCE(!ops_.empty()); - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - 
Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = StridedBatchedGemmParams; - using OpT = Op; - ParamsT params_; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -#define REGISTER_OP_COMMON(type, dtype, opa, opb, layout_string) \ - py::class_>(m, #type "_" #dtype "_" layout_string) \ - .def("SetRepeats", &type::SetRepeats) \ - .def("Profile", &type::Profile) \ - .def("Run", &type::Run) \ - .def("ListOps", &type::ListOps) \ - .def("SelectOp", &type::SelectOp) - -#define REGISTER_GEMM_HIPBLASLT(dtype, opa, opb, layout_string) \ - REGISTER_OP_COMMON(GemmHipBlasLt, dtype, opa, opb, layout_string) \ - .def(py::init()); - -#define REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \ - REGISTER_GEMM_HIPBLASLT(dtype, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_GEMM_HIPBLASLT(dtype, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_GEMM_HIPBLASLT(dtype, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_GEMM_HIPBLASLT(dtype, BlasOp::T, BlasOp::T, "TT"); - -#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, opa, opb, layout_string) \ - REGISTER_OP_COMMON(StridedBatchedGemmHipBlasLt, dtype, opa, opb, layout_string) \ - .def(py::init()); - -#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, BlasOp::T, BlasOp::T, "TT"); - -KE_REGISTER(m) { - REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(float); - REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(half); - - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(float); - REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(half); -} -#endif // USE_HIPBLASLT - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ke.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ke.h deleted file mode 100644 index 7b20732d2c9a3..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_ke.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -namespace onnxruntime { - -template -class IBatchedGemmKernelExplorer : public IKernelExplorer { - protected: - void CopyAsBsCsPointersToDevice(const std::vector& as, - const std::vector& bs, - const std::vector& cs, - int64_t batch) { - ORT_ENFORCE(as.size() == batch); - ORT_ENFORCE(bs.size() == batch); - ORT_ENFORCE(cs.size() == batch); - CopyPointersToDevice(as, dev_as_); - CopyPointersToDevice(bs, dev_bs_); - CopyPointersToDevice(cs, dev_cs_); - } - - static void CopyPointersToDevice(const std::vector& src, std::shared_ptr& dst) { - // convert pointers in vector to continuous for copying - std::vector tmp; - auto cvt_to_raw_ptr = [](const DeviceArray& x) { return static_cast(x.ptr()); }; - std::transform(src.cbegin(), src.cend(), std::back_inserter(tmp), cvt_to_raw_ptr); - - // create buffer for pointers - T** ptrs; - HIP_CALL_THROW(hipMalloc(&ptrs, src.size() * sizeof(T*))); - dst.reset(ptrs, [](void* addr) { HIP_CALL_THROW(hipFree(addr)); }); - - // copy host pointers to buffer - HIP_CALL_THROW(hipMemcpy(ptrs, tmp.data(), src.size() * sizeof(T*), hipMemcpyHostToDevice)); - } - - std::shared_ptr dev_as_; - std::shared_ptr dev_bs_; - std::shared_ptr dev_cs_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc deleted file mode 100644 index 8c3aceb3f741a..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
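The gemm_ke.h helper removed above packs a host-side vector of device pointers into a contiguous device buffer whose lifetime is tied to a std::shared_ptr with a hipFree deleter, so batched GEMM entry points can consume the pointer array directly. A self-contained approximation, with error handling reduced to a hypothetical `HIP_CHECK` macro instead of HIP_CALL_THROW:

```cpp
// Sketch of copying an array of device pointers to device memory (HIP).
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <vector>

#define HIP_CHECK(expr)                                      \
  do {                                                       \
    hipError_t err_ = (expr);                                \
    if (err_ != hipSuccess) {                                \
      std::fprintf(stderr, "%s\n", hipGetErrorString(err_)); \
      std::abort();                                          \
    }                                                        \
  } while (0)

template <typename T>
std::shared_ptr<T*> CopyPointersToDevice(const std::vector<T*>& host_ptrs) {
  T** dev = nullptr;
  HIP_CHECK(hipMalloc(reinterpret_cast<void**>(&dev), host_ptrs.size() * sizeof(T*)));
  // The shared_ptr owns the device-side pointer array and frees it on release.
  std::shared_ptr<T*> owner(dev, [](T** p) { HIP_CHECK(hipFree(p)); });
  HIP_CHECK(hipMemcpy(dev, host_ptrs.data(), host_ptrs.size() * sizeof(T*),
                      hipMemcpyHostToDevice));
  return owner;
}

int main() {
  std::vector<float*> as(4);
  for (auto& p : as)
    HIP_CHECK(hipMalloc(reinterpret_cast<void**>(&p), 256 * sizeof(float)));
  auto dev_as = CopyPointersToDevice(as);  // usable as the `as` argument of a batched GEMM
  for (auto p : as) HIP_CHECK(hipFree(p));
  return 0;
}
```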
- -#include -#include - -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/gemm_rocblas.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_ke.h" - -using namespace onnxruntime::rocm::tunable::blas; -using namespace onnxruntime::rocm::tunable::blas::internal; - -namespace py = pybind11; - -namespace onnxruntime { - -template -class RocBlasGemm : public IKernelExplorer { - public: - RocBlasGemm(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - double beta, - DeviceArray& c, int64_t ldc) - : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = alpha; - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.beta = beta; - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - type_strings_.emplace_back("RocBlasGemmDefault"); - ops_.emplace_back([](auto* params) { return RocBlasGemmOp(params); }); - -#ifdef USE_ROCBLAS_EXTENSION_API - for (auto&& [type_string, op] : GetRocBlasGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } -#endif - } - - ~RocBlasGemm() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - rocblas_handle rocblas_handle_; - - using ParamsT = GemmParams; - using OpT = Op; - - ParamsT params_{}; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -template -class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer { - public: - RocBlasBatchedGemm(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - std::vector& as, int64_t lda, - std::vector& bs, int64_t ldb, - double beta, - std::vector& cs, int64_t ldc, - int64_t batch) - : params_{} { - this->CopyAsBsCsPointersToDevice(as, bs, cs, batch); - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = this->TuningContext(); - params_.stream = this->Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = alpha; - params_.as = const_cast(this->dev_as_.get()); - params_.lda = lda; - params_.bs = const_cast(this->dev_bs_.get()); - params_.ldb = ldb; - params_.beta = beta; - params_.cs = this->dev_cs_.get(); - params_.ldc = ldc; - params_.batch = batch; - - type_strings_.emplace_back("RocBlasBatchedGemmDefault"); - ops_.emplace_back([](auto* params) { return RocBlasBatchedGemmOp(params); }); - -#ifdef USE_ROCBLAS_EXTENSION_API - for (auto&& [type_string, op] : GetRocBlasBatchedGemmTypeStringAndOps()) { 
- type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } -#endif - } - - ~RocBlasBatchedGemm() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - rocblas_handle rocblas_handle_; - - using ParamsT = BatchedGemmParams; - using OpT = Op; - - ParamsT params_{}; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -template -class RocBlasStridedBatchedGemm : public IKernelExplorer { - public: - RocBlasStridedBatchedGemm(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, int64_t stride_a, - DeviceArray& b, int64_t ldb, int64_t stride_b, - double beta, - DeviceArray& c, int64_t ldc, int64_t stride_c, - int64_t batch) - : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = alpha; - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.stride_a = stride_a; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.stride_b = stride_b; - params_.beta = beta; - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - params_.stride_c = stride_c; - params_.batch = batch; - - type_strings_.emplace_back("RocBlasStridedBatchedGemmDefault"); - ops_.emplace_back([](auto* params) { return RocBlasStridedBatchedGemmOp(params); }); - -#ifdef USE_ROCBLAS_EXTENSION_API - for (auto&& [type_string, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } -#endif - } - - ~RocBlasStridedBatchedGemm() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - rocblas_handle rocblas_handle_; - - using ParamsT = StridedBatchedGemmParams; - using OpT = Op; - - ParamsT params_{}; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; - -#define REGISTER_OP_COMMON(type, dtype) \ - py::class_>(mod, #type "_" #dtype) \ - .def("SetRepeats", &type::SetRepeats) \ - .def("Profile", &type::Profile) \ - .def("Run", &type::Run) \ - .def("ListOps", &type::ListOps) \ - .def("SelectOp", &type::SelectOp) - -#define REGISTER_GEMM(dtype) \ - REGISTER_OP_COMMON(RocBlasGemm, dtype) \ - .def(py::init()) - -#define REGISTER_BATCHED_GEMM(dtype) \ - REGISTER_OP_COMMON(RocBlasBatchedGemm, dtype) \ - .def(py::init&, int64_t, \ - std::vector&, int64_t, \ - double, \ - std::vector&, int64_t, \ - int64_t>()) - -#define REGISTER_STRIDED_BATCHED_GEMM(dtype) 
\ - REGISTER_OP_COMMON(RocBlasStridedBatchedGemm, dtype) \ - .def(py::init()) - -KE_REGISTER(mod) { - REGISTER_GEMM(float); - REGISTER_GEMM(half); - - REGISTER_BATCHED_GEMM(float); - REGISTER_BATCHED_GEMM(half); - - REGISTER_STRIDED_BATCHED_GEMM(float); - REGISTER_STRIDED_BATCHED_GEMM(half); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu deleted file mode 100644 index 7068fc8fd0ebc..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +++ /dev/null @@ -1,369 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "pybind11/stl.h" - -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -#include - -namespace py = pybind11; - -using namespace onnxruntime::contrib::rocm; - -namespace onnxruntime { - -template -class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { - public: - IGemmSoftmaxGemmPermuteKernelExplorer( - int64_t batch, - int64_t seqlen, - int64_t total_seqlen, - std::optional max_seqlen, - int64_t num_heads, - int64_t head_size, - int64_t mask_dim, - double scale, - bool causal, - contrib::AttentionQkvFormat qkv_format, - DeviceArray& Q, - std::optional& K, - std::optional& V, - std::optional& attn_bias, - std::optional& attn_mask, - DeviceArray& out) { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - - attn_.batch_size = batch; - attn_.sequence_length = seqlen; - // NOTE: This test wrapper does not support past present concat, then past_sequence_length = 0 always holds. 
- // Thus, total_sequence_length = past_sequence_length + kv_sequence_length further implies - // total_sequence_length == kv_sequence_length - attn_.kv_sequence_length = total_seqlen; - attn_.past_sequence_length = 0; - attn_.total_sequence_length = total_seqlen; - attn_.max_sequence_length = 0; - attn_.hidden_size = num_heads * head_size; - attn_.head_size = head_size; - attn_.v_hidden_size = attn_.hidden_size; // Q,K,V hidden size must agree now - attn_.v_head_size = attn_.head_size; // Q,K,V hidden size must agree now - attn_.num_heads = num_heads; - attn_.is_unidirectional = causal; - attn_.past_present_share_buffer = false; - attn_.do_rotary = false; - attn_.mask_filter_value = -10000.0f; - attn_.scale = scale; - if (mask_dim == 0) { - attn_.mask_type = contrib::MASK_NONE; - } else if (mask_dim == 2) { - attn_.mask_type = contrib::MASK_2D_KEY_PADDING; - } else if (mask_dim == 3) { - attn_.mask_type = contrib::MASK_3D_ATTENTION; - } else if (mask_dim == 4) { - attn_.mask_type = contrib::MASK_4D_MEGATRON; - } else { - ORT_ENFORCE(false, "mask type not supported"); - } - attn_.qkv_format = qkv_format; - switch (qkv_format) { - case contrib::Q_K_V_BNSH: - case contrib::Q_K_V_BSNH: - attn_.mode = contrib::rocm::QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; - break; - case contrib::Q_KV_BSNH_BSN2H: - attn_.mode = contrib::rocm::BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; - break; - case contrib::QKV_BSN3H: - attn_.mode = contrib::rocm::BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; - break; - default: - ORT_NOT_IMPLEMENTED("qkv_format ", qkv_format, " is not implemented"); - } - - device_prop = GetEp()->GetDeviceProp(); - - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.attention = &attn_; - params_.device_prop = &device_prop; - params_.scale = scale; - - std::tie(params_.q_buffer, params_.k_buffer, params_.v_buffer) = ConvertToOffsetedBufferViews( - &attn_, Q.ptr(), K.has_value() ? K->ptr() : nullptr, V.has_value() ? V->ptr() : nullptr); - - if (attn_bias.has_value()) { - params_.bias_buffer = reinterpret_cast(attn_bias->ptr()); - } - if (attn_mask.has_value()) { - params_.mask_index_buffer = reinterpret_cast(attn_mask->ptr()); - if (mask_dim == 2) { - params_.mask_index_dims = {batch, total_seqlen}; - } else if (mask_dim == 3) { - params_.mask_index_dims = {batch, seqlen, total_seqlen}; - } else if (mask_dim == 4) { - ORT_ENFORCE(max_seqlen.has_value()); - attn_.max_sequence_length = max_seqlen.value(); - ORT_ENFORCE(attn_.max_sequence_length >= seqlen); - attn_.past_sequence_length = attn_.max_sequence_length - seqlen; - params_.mask_index_dims = {batch, 1, attn_.max_sequence_length, attn_.max_sequence_length}; - } - } - params_.out_buffer = reinterpret_cast(out.ptr()); - } - - ~IGemmSoftmaxGemmPermuteKernelExplorer() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - } - - void SetWorkspace(size_t num_bytes) { - void* ptr; - HIP_CALL_THROW(hipMalloc(&ptr, num_bytes)); - workspace_.reset(ptr, [](void* ptr) { HIP_CALL_THROW(hipFree(ptr)); }); - params_.workspace_buffer = reinterpret_cast(workspace_.get()); - } - - protected: - using ParamsT = contrib::rocm::GemmSoftmaxGemmPermuteParams; - rocblas_handle rocblas_handle_; - hipDeviceProp_t device_prop; - contrib::rocm::RocmAttentionParameters attn_; - ParamsT params_; - std::shared_ptr workspace_; -}; - -// The pipeline composed from rocblas api calls and kernel launches. 
-template -class GemmSoftmaxGemmPermuteGeneric : public IGemmSoftmaxGemmPermuteKernelExplorer { - public: - GemmSoftmaxGemmPermuteGeneric( - int64_t batch, - int64_t seqlen, - int64_t total_seqlen, - std::optional max_seqlen, - int64_t num_heads, - int64_t head_size, - int64_t mask_dim, - double scale, - bool causal, - contrib::AttentionQkvFormat qkv_format, - DeviceArray& Q, - std::optional& K, - std::optional& V, - std::optional& attn_bias, - std::optional& attn_mask, - DeviceArray& out) - : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, causal, qkv_format, - Q, K, V, attn_bias, attn_mask, out) { - this->SetWorkspace(GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_)); - } - - std::vector ListOps() const { - return {"Generic"}; - } - - bool SelectOp(const std::string&) { - return true; - } - - void Run() override { - ORT_THROW_IF_ERROR(GemmSoftmaxGemmPermuteGenericPipeline::Run( - &this->params_, /*use_persistent_softmax=*/false)); - } -}; - -template -class GemmSoftmaxGemmPermuteGenericNestedTunable : public GemmSoftmaxGemmPermuteGeneric { - public: - GemmSoftmaxGemmPermuteGenericNestedTunable( - int64_t batch, - int64_t seqlen, - int64_t total_seqlen, - std::optional max_seqlen, - int64_t num_heads, - int64_t head_size, - int64_t mask_dim, - double scale, - bool causal, - contrib::AttentionQkvFormat qkv_format, - DeviceArray& Q, - std::optional& K, - std::optional& V, - std::optional& attn_bias, - std::optional& attn_mask, - DeviceArray& out) - : GemmSoftmaxGemmPermuteGeneric(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, causal, qkv_format, - Q, K, V, attn_bias, attn_mask, out) { - this->params_.TuningContext()->EnableTunableOpAndTuning(); - } -}; - -#ifdef USE_COMPOSABLE_KERNEL -template -class GemmSoftmaxGemmPermuteCK : public IGemmSoftmaxGemmPermuteKernelExplorer { - public: - GemmSoftmaxGemmPermuteCK( - int64_t batch, - int64_t seqlen, - int64_t total_seqlen, - std::optional max_seqlen, - int64_t num_heads, - int64_t head_size, - int64_t mask_dim, - double scale, - bool causal, - contrib::AttentionQkvFormat qkv_format, - DeviceArray& Q, - std::optional& K, - std::optional& V, - std::optional& attn_bias, - std::optional& attn_mask, - DeviceArray& out) - : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, causal, qkv_format, - Q, K, V, attn_bias, attn_mask, out) { - this->SetWorkspace(GemmSoftmaxGemmPermuteTunableOp::GetWorkspaceNumBytes(&this->attn_)); - - for (auto&& [ts, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - type_strings_.emplace_back(std::move(ts)); - ops_.emplace_back(std::move(op)); - } - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i].IsSupported(&this->params_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](&this->params_)); - } - - private: - using ParamsT = typename IGemmSoftmaxGemmPermuteKernelExplorer::ParamsT; - using OpT = Op; - - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; -#endif // USE_COMPOSABLE_KERNEL - -// The pipeline composed from rocblas api calls and kernel launches. 
-template -class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplorer { - public: - GemmSoftmaxGemmPermuteTunable( - int64_t batch, - int64_t seqlen, - int64_t total_seqlen, - std::optional max_seqlen, - int64_t num_heads, - int64_t head_size, - int64_t mask_dim, - double scale, - bool causal, - contrib::AttentionQkvFormat qkv_format, - DeviceArray& Q, - std::optional& K, - std::optional& V, - std::optional& attn_bias, - std::optional& attn_mask, - DeviceArray& out) - : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, causal, qkv_format, - Q, K, V, attn_bias, attn_mask, out) { - this->SetWorkspace(std::max( - GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_), - GemmSoftmaxGemmPermuteTunableOp::GetWorkspaceNumBytes(&this->attn_))); - - this->params_.TuningContext()->EnableTunableOpAndTuning(); - } - - std::vector ListOps() const { - return {"Tunable"}; - } - - bool SelectOp(const std::string&) { - return true; - } - - void Run() override { - ORT_THROW_IF_ERROR(op_(&this->params_)); - } - - // NOTE: this op is expensive to construct - GemmSoftmaxGemmPermuteTunableOp op_{}; -}; - -#define REGISTER_COMMON(name, type, ...) \ - py::class_>(m, name) \ - .def(py::init, int64_t, int64_t, int64_t, \ - float, bool, contrib::AttentionQkvFormat, \ - DeviceArray&, \ - std::optional&, \ - std::optional&, \ - std::optional&, \ - std::optional&, \ - DeviceArray&>()) \ - .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ - .def("Run", &type<__VA_ARGS__>::Run) \ - .def("Profile", &type<__VA_ARGS__>::Profile) \ - .def("ListOps", &type<__VA_ARGS__>::ListOps) \ - .def("SelectOp", &type<__VA_ARGS__>::SelectOp); - -#define REGISTER_GENERIC(dtype) \ - REGISTER_COMMON("GemmSoftmaxGemmPermuteGeneric_" #dtype, GemmSoftmaxGemmPermuteGeneric, dtype) - -#define REGISTER_GENERIC_NESTEDTUNABLE(dtype) \ - REGISTER_COMMON("GemmSoftmaxGemmPermuteGenericNestedTunable_" #dtype, GemmSoftmaxGemmPermuteGenericNestedTunable, dtype) - -#define REGISTER_CK(dtype, biased, masked, mask_bias_suffix) \ - REGISTER_COMMON( \ - "GemmSoftmaxGemmPermuteCK" mask_bias_suffix "_" #dtype, GemmSoftmaxGemmPermuteCK, dtype, biased, masked) - -#define REGISTER_TUNABLE(dtype) \ - REGISTER_COMMON("GemmSoftmaxGemmPermuteTunable_" #dtype, GemmSoftmaxGemmPermuteTunable, dtype) - -KE_REGISTER(m) { - auto qkv_format = m.def_submodule("qkv_format"); - py::enum_(qkv_format, "qkv_format") - .value("Q_K_V_BNSH", contrib::AttentionQkvFormat::Q_K_V_BNSH, "") - .value("Q_K_V_BSNH", contrib::AttentionQkvFormat::Q_K_V_BSNH, "") - .value("QKV_BSN3H", contrib::AttentionQkvFormat::QKV_BSN3H, "") - .value("Q_KV_BSNH_BSN2H", contrib::AttentionQkvFormat::Q_KV_BSNH_BSN2H, "") - .export_values(); - - REGISTER_GENERIC(half); - REGISTER_GENERIC(float); - REGISTER_GENERIC_NESTEDTUNABLE(half); - REGISTER_GENERIC_NESTEDTUNABLE(float); - -#ifdef USE_COMPOSABLE_KERNEL - REGISTER_CK(half, false, false, ""); - REGISTER_CK(half, true, false, "Biased"); - REGISTER_CK(half, false, true, "Masked"); - REGISTER_CK(half, true, true, "BiasedMasked"); -#endif - - REGISTER_TUNABLE(half); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu deleted file mode 100644 index e1d9b5de20e00..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright (c) Microsoft 
Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/gemm_tunable.cuh" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "python/tools/kernel_explorer/kernels/rocm/gemm_ke.h" - -using namespace onnxruntime::rocm::tunable::blas; -using namespace onnxruntime::rocm::tunable::blas::internal; - -namespace onnxruntime { - -template -class GemmTunable : public IKernelExplorer { - public: - GemmTunable(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - double beta, - DeviceArray& c, int64_t ldc) { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - ~GemmTunable() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - WithMaxTuningDurationMs max_duration(TuningContext(), 250); - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - std::vector ListOps() const { - return {"Tunable"}; - } - - bool SelectOp(const std::string& name) { - return name == "Tunable"; - } - - private: - using ParamsT = GemmParams; - ParamsT params_; - - // tunable is stateful, store it as an instance - GemmTunableOp op_{}; - rocblas_handle rocblas_handle_; -}; - -template -class BatchedGemmTunable : public IBatchedGemmKernelExplorer { - public: - BatchedGemmTunable(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - std::vector& as, int64_t lda, - std::vector& bs, int64_t ldb, - double beta, - std::vector& cs, int64_t ldc, - int64_t batch) { - this->CopyAsBsCsPointersToDevice(as, bs, cs, batch); - - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = this->TuningContext(); - params_.stream = this->Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.as = const_cast(this->dev_as_.get()); - params_.lda = lda; - params_.bs = const_cast(this->dev_bs_.get()); - params_.ldb = ldb; - params_.beta = static_cast(beta); - params_.cs = this->dev_cs_.get(); - params_.ldc = ldc; - params_.batch = batch; - - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - ~BatchedGemmTunable() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - WithMaxTuningDurationMs max_duration(params_.TuningContext(), 250); - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - std::vector ListOps() const { - return {"Tunable"}; - } - - bool SelectOp(const std::string& name) { - return name == "Tunable"; - } - - private: - using ParamsT = BatchedGemmParams; - ParamsT params_; - - // tunable is stateful, store it as an instance - BatchedGemmTunableOp op_{}; - rocblas_handle 
rocblas_handle_; -}; - -template -class StridedBatchedGemmTunable : public IKernelExplorer { - public: - StridedBatchedGemmTunable(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, int64_t stride_a, - DeviceArray& b, int64_t ldb, int64_t stride_b, - double beta, - DeviceArray& c, int64_t ldc, int64_t stride_c, - int64_t batch) { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = static_cast(alpha); - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.stride_a = stride_a; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.stride_b = stride_b; - params_.beta = static_cast(beta); - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - params_.stride_c = stride_c; - params_.batch = batch; - - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - ~StridedBatchedGemmTunable() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - WithMaxTuningDurationMs max_duration(params_.TuningContext(), 250); - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - std::vector ListOps() const { - return {"Tunable"}; - } - - bool SelectOp(const std::string& name) { - return name == "Tunable"; - } - - private: - using ParamsT = StridedBatchedGemmParams; - ParamsT params_; - - // tunable is stateful, store it as an instance - StridedBatchedGemmTunableOp op_{}; - rocblas_handle rocblas_handle_; -}; - -#define REGISTER_OP_COMMON(type, dtype, opa, opb, layout_string) \ - py::class_>(m, #type "_" #dtype "_" layout_string) \ - .def("SetRepeats", &type::SetRepeats) \ - .def("Profile", &type::Profile) \ - .def("Run", &type::Run) \ - .def("ListOps", &type::ListOps) \ - .def("SelectOp", &type::SelectOp) - -#define REGISTER_GEMM(dtype, opa, opb, layout_string) \ - REGISTER_OP_COMMON(GemmTunable, dtype, opa, opb, layout_string) \ - .def(py::init()) - -#define REGISTER_GEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_GEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_GEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_GEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_GEMM(dtype, BlasOp::T, BlasOp::T, "TT"); - -#define REGISTER_BATCHED_GEMM(dtype, opa, opb, layout_string) \ - REGISTER_OP_COMMON(BatchedGemmTunable, dtype, opa, opb, layout_string) \ - .def(py::init&, int64_t, \ - std::vector&, int64_t, \ - double, \ - std::vector&, int64_t, \ - int64_t>()) - -#define REGISTER_BATCHED_GEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_BATCHED_GEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_BATCHED_GEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_BATCHED_GEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_BATCHED_GEMM(dtype, BlasOp::T, BlasOp::T, "TT"); - -#define REGISTER_STRIDED_BATCHED_GEMM(dtype, opa, opb, layout_string) \ - REGISTER_OP_COMMON(StridedBatchedGemmTunable, dtype, opa, opb, layout_string) \ - .def(py::init()) - -#define REGISTER_STRIDED_BATCHED_GEMM_FOR_ALL_TRANSAB(dtype) \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::N, BlasOp::N, "NN"); \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::N, BlasOp::T, "NT"); \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::T, BlasOp::N, "TN"); \ - REGISTER_STRIDED_BATCHED_GEMM(dtype, BlasOp::T, BlasOp::T, "TT"); - -KE_REGISTER(m) { - REGISTER_GEMM_FOR_ALL_TRANSAB(float); - 
REGISTER_GEMM_FOR_ALL_TRANSAB(half); - - REGISTER_BATCHED_GEMM_FOR_ALL_TRANSAB(float); - REGISTER_BATCHED_GEMM_FOR_ALL_TRANSAB(half); - - REGISTER_STRIDED_BATCHED_GEMM_FOR_ALL_TRANSAB(float); - REGISTER_STRIDED_BATCHED_GEMM_FOR_ALL_TRANSAB(half); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu deleted file mode 100644 index 6af163ab94b10..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -namespace py = pybind11; -using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes; -namespace onnxruntime { - -template -class GroupNormNHWC : public IKernelExplorer { - public: - GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias, - DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, - bool broadcast_skip, int channels_per_block) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), - static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), - epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, - channels_per_block) { - type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize); - } - - void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - std::vector ListOps() const { - return {type_string_}; - } - - bool SelectOp(const std::string& name) { - Status status = op_.IsSupported(¶ms_); - return status.IsOK() && name == type_string_; - } - - private: - using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; - ParamsT params_{}; - contrib::rocm::GroupNormNHWCOp op_{}; - std::string type_string_{}; -}; - -template -class GroupNormNHWCStaticSelection : public IKernelExplorer { - public: - GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, - DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, - float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), - static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), - epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, - channels_per_block) { - type_string_ = "GroupNormNHWCStaticSelection"; - } - - void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::GroupNormNHWCStaticSelection(¶ms_))); - } - - std::vector ListOps() const { - return {type_string_}; - } - - bool SelectOp(const std::string& name) { - 
Status status = contrib::rocm::GroupNormNHWCStaticSelection(¶ms_); - return status.IsOK() && name == type_string_; - } - - private: - using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; - ParamsT params_{}; - std::string type_string_{}; -}; - -template -class GroupNormNHWCTunable : public IKernelExplorer { - public: - GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, - DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, - float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), - static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), - epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, - channels_per_block) { - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - std::vector ListOps() const { - return {"GroupNormNHWCTunable"}; - } - - bool SelectOp(const std::string& name) { - return name == "GroupNormNHWCTunable"; - } - - private: - using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; - ParamsT params_{}; - contrib::rocm::GroupNormNHWCTunableOp op_{}; -}; - -#ifdef USE_COMPOSABLE_KERNEL -template -class CKGroupNormNHWC : public IKernelExplorer { - public: - CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, - DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, - float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), - static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), - epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, - channels_per_block) { - for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; - using OpT = rocm::tunable::Op; - ParamsT params_{}; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL -template -class GroupNormNHWCTriton : public IKernelExplorer { - public: - GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, - DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, - float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int 
channels_per_block) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), - static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), - epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, - channels_per_block) { - for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { - name_strings_.emplace_back(name); - ops_.emplace_back(std::move(op)); - } - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return name_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (name_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; - using OpT = rocm::tunable::Op; - ParamsT params_{}; - std::vector ops_; - std::vector name_strings_; - size_t selected_op_{}; -}; -#endif // USE_TRITON_KERNEL - -#define REGISTER_OP(name, type, threads_per_block, vec_size) \ - py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ - .def("ListOps", &name::ListOps) \ - .def("SelectOp", &name::SelectOp); - -#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ - REGISTER_OP(name, type, threads_per_block, 1) \ - REGISTER_OP(name, type, threads_per_block, 2) \ - REGISTER_OP(name, type, threads_per_block, 4) - -#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) - -#define REGISTER_COMMON(name, type, ...) 
\ - py::class_>(m, name) \ - .def(py::init()) \ - .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ - .def("Profile", &type<__VA_ARGS__>::Profile) \ - .def("Run", &type<__VA_ARGS__>::Run) \ - .def("ListOps", &type<__VA_ARGS__>::ListOps) \ - .def("SelectOp", &type<__VA_ARGS__>::SelectOp); - -#define REGISTER_OP_TYPED(name, type) \ - REGISTER_COMMON(#name "_" #type, name, type) - -#define REGISTER_CK(type, with_silu, silu_suffix) \ - REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu) - -#define REGISTER_TRITON(type, with_silu, silu_suffix) \ - REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu) - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half); - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, float); - - REGISTER_OP_TYPED(GroupNormNHWCTunable, half); - REGISTER_OP_TYPED(GroupNormNHWCTunable, float); - - REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, half); - REGISTER_OP_TYPED(GroupNormNHWCStaticSelection, float); - -#ifdef USE_COMPOSABLE_KERNEL - REGISTER_CK(half, false, "Pass"); - REGISTER_CK(half, true, "Silu"); - REGISTER_CK(float, false, "Pass"); - REGISTER_CK(float, true, "Silu"); -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL - REGISTER_TRITON(half, false, "Pass"); - REGISTER_TRITON(half, true, "Silu"); - REGISTER_TRITON(float, false, "Pass"); - REGISTER_TRITON(float, true, "Silu"); -#endif -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/matmul_4bits.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/matmul_4bits.cu deleted file mode 100644 index cc9c2ed2862c9..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/matmul_4bits.cu +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include -#include - -#include - -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "contrib_ops/rocm/quantization/matmul_nbits.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -// Extend the OpParams so that all specializations have the same parameter passing interface -template -struct MatrixFloatInt4Params : rocm::tunable::OpParams { - std::string Signature() const override { return std::to_string(n_); } - - T* output_; - const T* a_; - const uint8_t* b_; - const T* scales_; - const uint8_t* zero_points_; - int m_; - int n_; - int k_; -}; - -template -class MatrixFloatInt4 : public IKernelExplorer { - public: - MatrixFloatInt4(DeviceArray& output, - DeviceArray& a, - DeviceArray& b, - DeviceArray& scales, - int m, int n, int k) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.output_ = static_cast(output.ptr()); - params_.a_ = static_cast(a.ptr()); - params_.b_ = static_cast(b.ptr()); - params_.scales_ = static_cast(scales.ptr()); - params_.zero_points_ = nullptr; - params_.m_ = m; - params_.n_ = n; - params_.k_ = k; - - HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, 0)); - } - - MatrixFloatInt4(DeviceArray& output, - DeviceArray& a, - DeviceArray& b, - DeviceArray& scales, - DeviceArray& zeropoints, - int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) { - params_.zero_points_ = static_cast(zeropoints.ptr()); - } - - void Run() override { - contrib::rocm::TryMatMul4Bits( - params_.output_, - params_.a_, - params_.b_, - params_.scales_, - params_.zero_points_, - params_.m_, - params_.n_, - params_.k_, - 32, - static_cast(device_prop_.sharedMemPerBlock), - params_.StreamHandle()); - } - - private: - // A VectorAddOp is a callable that can process const VectorAddParams* - using ParamsT = MatrixFloatInt4Params; - ParamsT params_{}; - hipDeviceProp_t device_prop_; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); - -KE_REGISTER(m) { - REGISTER_OP(MatrixFloatInt4, half); - REGISTER_OP(MatrixFloatInt4, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu deleted file mode 100644 index ec353f7e91f86..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -namespace py = pybind11; - -namespace onnxruntime { - -template -class SkipLayerNormSmall : public IKernelExplorer { - public: - SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, - DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, - float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), - static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} - - void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp(¶ms_))); - } - - bool IsSupported() { - Status status = contrib::rocm::SkipLayerNormSmallOp(¶ms_); - return status.IsOK(); - } - - private: - using ParamsT = contrib::rocm::SkipLayerNormParams; - ParamsT params_{}; -}; - -template -class SkipLayerNormRegular : public IKernelExplorer { - public: - SkipLayerNormRegular(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, - DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, - float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), - static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} - - void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp(¶ms_))); - } - - bool IsSupported() { - Status status = contrib::rocm::SkipLayerNormRegularOp(¶ms_); - return status.IsOK(); - } - - private: - using ParamsT = contrib::rocm::SkipLayerNormParams; - ParamsT params_{}; -}; - -template -class SkipLayerNormStaticSelection : public IKernelExplorer { - public: - SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, - DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, - float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), - static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} - - void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection(¶ms_))); - } - - bool IsSupported() { - Status status = contrib::rocm::SkipLayerNormStaticSelection(¶ms_); - return status.IsOK(); - } - - private: - using ParamsT = contrib::rocm::SkipLayerNormParams; - ParamsT params_{}; -}; - -template -class SkipLayerNormTunable : public IKernelExplorer { - public: - SkipLayerNormTunable(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, - DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, - float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), - static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), - static_cast(beta.ptr()), 
static_cast(bias.ptr()), epsilon, hidden_size, element_count) { - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - bool IsSupported() { - return true; - } - - private: - using ParamsT = contrib::rocm::SkipLayerNormParams; - ParamsT params_{}; - contrib::rocm::SkipLayerNormTunableOp op_{}; -}; - -#define REGISTER_OP_COMMON(name, type, ...) \ - py::class_>(m, name) \ - .def(py::init()) \ - .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ - .def("Profile", &type<__VA_ARGS__>::Profile) \ - .def("Run", &type<__VA_ARGS__>::Run) \ - .def("IsSupported", &type<__VA_ARGS__>::IsSupported); - -#define REGISTER_OP(name, type, threads_per_block, vec_size) \ - REGISTER_OP_COMMON("Simplified" #name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, true) \ - REGISTER_OP_COMMON(#name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, false) - -#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ - REGISTER_OP(name, type, threads_per_block, 1) \ - REGISTER_OP(name, type, threads_per_block, 2) \ - REGISTER_OP(name, type, threads_per_block, 4) \ - REGISTER_OP(name, type, threads_per_block, 8) \ - REGISTER_OP(name, type, threads_per_block, 16) - -#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 576) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 640) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 704) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 768) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 832) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 896) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 1024) - -#define REGISTER_COMMON(name, type, ...) \ - py::class_>(m, name) \ - .def(py::init()) \ - .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ - .def("Profile", &type<__VA_ARGS__>::Profile) \ - .def("Run", &type<__VA_ARGS__>::Run) \ - .def("IsSupported", &type<__VA_ARGS__>::IsSupported); - -#define REGISTER_OP_TYPED(name, type) \ - REGISTER_COMMON("Simplified" #name "_" #type, name, type, true) \ - REGISTER_COMMON(#name "_" #type, name, type, false) - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half); - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float); - - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, half); - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, float); - - REGISTER_OP_TYPED(SkipLayerNormTunable, half); - REGISTER_OP_TYPED(SkipLayerNormTunable, float); - - REGISTER_OP_TYPED(SkipLayerNormStaticSelection, half); - REGISTER_OP_TYPED(SkipLayerNormStaticSelection, float); -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu deleted file mode 100644 index 6cc59ee2579c2..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright (c) Microsoft Corporation. 
All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include -#include -#include - -#include "core/providers/rocm/math/softmax_ck.cuh" -#include "core/providers/rocm/math/softmax_tunable_op.cuh" -#include "core/providers/rocm/shared_inc/accumulation_type.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "core/providers/rocm/math/softmax_triton.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -template -class SoftmaxBlockwise : public IKernelExplorer { - public: - SoftmaxBlockwise(DeviceArray& output, DeviceArray& input, int softmax_elements, - int input_stride, int output_stride, int batch_count, bool is_log_softmax) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { - type_string_ = "SoftmaxBlockwise_" + std::to_string(VecSize); - } - - void Run() override { - ORT_THROW_IF_ERROR((rocm::SoftmaxBlockwiseOp, VecSize>(¶ms_))); - } - - std::vector ListOps() const { - return {type_string_}; - } - - bool SelectOp(const std::string& name) { - Status status = rocm::SoftmaxBlockwiseOp, VecSize>(¶ms_); - return status.IsOK() && name == type_string_; - } - - private: - using ParamsT = rocm::SoftmaxParams; - ParamsT params_{}; - std::string type_string_{}; -}; - -template -class SoftmaxWarpwiseStaticSelection : public IKernelExplorer { - public: - SoftmaxWarpwiseStaticSelection(DeviceArray& output, DeviceArray& input, int softmax_elements, - int input_stride, int output_stride, int batch_count, bool is_log_softmax) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {} - - void Run() override { - ORT_THROW_IF_ERROR((rocm::SoftmaxWarpwiseStaticSelection>(¶ms_))); - } - - std::vector ListOps() const { - return {"SoftmaxWarpwiseStaticSelection"}; - } - - bool SelectOp(const std::string& name) { - auto status = rocm::SoftmaxWarpwiseStaticSelection>(¶ms_); - return status.IsOK() && name == "SoftmaxWarpwiseStaticSelection"; - } - - private: - using ParamsT = rocm::SoftmaxParams; - ParamsT params_{}; -}; - -template -class SoftmaxBlockwiseStaticSelection : public IKernelExplorer { - public: - SoftmaxBlockwiseStaticSelection(DeviceArray& output, DeviceArray& input, int softmax_elements, - int input_stride, int output_stride, int batch_count, bool is_log_softmax) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {} - - void Run() override { - ORT_THROW_IF_ERROR((rocm::SoftmaxBlockwiseStaticSelection>(¶ms_))); - } - - std::vector ListOps() const { - return {"SoftmaxBlockwiseStaticSelection"}; - } - - bool SelectOp(const std::string& name) { - return name == "SoftmaxBlockwiseStaticSelection"; - } - - private: - using ParamsT = rocm::SoftmaxParams; - ParamsT params_{}; -}; - -template -class SoftmaxTunable : public IKernelExplorer { - public: - SoftmaxTunable(DeviceArray& output, DeviceArray& input, int softmax_elements, - int input_stride, int output_stride, int batch_count, bool is_log_softmax) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { - params_.TuningContext()->EnableTunableOpAndTuning(); - } 
- - void Run() override { - WithMaxTuningDurationMs max_duration(TuningContext(), 250); - ORT_THROW_IF_ERROR(op_(¶ms_)); - } - - std::vector ListOps() const { - return {"SoftmaxTunable"}; - } - - bool SelectOp(const std::string& name) { - return name == "SoftmaxTunable"; - } - - private: - using ParamsT = rocm::SoftmaxParams; - ParamsT params_{}; - rocm::SoftmaxTunableOp> op_{}; -}; - -#ifdef USE_COMPOSABLE_KERNEL -template -class CKSoftmax : public IKernelExplorer { - public: - CKSoftmax(DeviceArray& output, DeviceArray& input, int softmax_elements, - int input_stride, int output_stride, int batch_count, bool is_log_softmax) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { - for (auto&& [type_string, op] : rocm::GetCKSoftmaxTypeStringAndOps>()) { - type_strings_.emplace_back(std::move(type_string)); - ops_.emplace_back(std::move(op)); - } - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return type_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (type_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = rocm::SoftmaxParams; - using OpT = rocm::tunable::Op; - ParamsT params_{}; - std::vector ops_; - std::vector type_strings_; - size_t selected_op_{}; -}; -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL -template -class SoftmaxTriton : public IKernelExplorer { - public: - SoftmaxTriton(DeviceArray& output, DeviceArray& input, int softmax_elements, - int input_stride, int output_stride, int batch_count, bool is_log_softmax) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { - for (auto&& [name, op] : rocm::GetSoftmaxTritonOps()) { - name_strings_.emplace_back(name); - ops_.emplace_back(std::move(op)); - } - } - - void Run() override { - ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); - } - - std::vector ListOps() const { - return name_strings_; - } - - bool SelectOp(const std::string& name) { - for (size_t i = 0; i < ops_.size(); i++) { - if (name_strings_[i] == name) { - selected_op_ = i; - Status status = ops_[i](¶ms_); - return status.IsOK(); - } - } - - ORT_THROW("Cannot find implementation ", name); - } - - private: - using ParamsT = rocm::SoftmaxParams; - using OpT = rocm::tunable::Op; - ParamsT params_{}; - std::vector ops_; - std::vector name_strings_; - size_t selected_op_{}; -}; - -#endif // USE_TRITON_KERNEL - -#define REGISTER_OP(name, type, vec_size) \ - py::class_>(m, #name "_" #type "_" #vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ - .def("ListOps", &name::ListOps) \ - .def("SelectOp", &name::SelectOp); - -#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type) \ - REGISTER_OP(name, type, 1) \ - REGISTER_OP(name, type, 2) \ - REGISTER_OP(name, type, 4) \ - REGISTER_OP(name, type, 8) \ - REGISTER_OP(name, type, 16) - -#define REGISTER_OP_TYPED(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ - .def("ListOps", &name::ListOps) \ - .def("SelectOp", 
&name::SelectOp); - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, half); - REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, float); - - REGISTER_OP_TYPED(SoftmaxWarpwiseStaticSelection, half); - REGISTER_OP_TYPED(SoftmaxWarpwiseStaticSelection, float); - - REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, half); - REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, float); - - REGISTER_OP_TYPED(SoftmaxTunable, half); - REGISTER_OP_TYPED(SoftmaxTunable, float); -} - -#ifdef USE_COMPOSABLE_KERNEL -KE_REGISTER(m) { - REGISTER_OP_TYPED(CKSoftmax, half); - REGISTER_OP_TYPED(CKSoftmax, float); -} -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL -KE_REGISTER(m) { - REGISTER_OP_TYPED(SoftmaxTriton, half); - REGISTER_OP_TYPED(SoftmaxTriton, float); -} -#endif // USE_TRITON_KERNEL - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py deleted file mode 100644 index bfe13fac2a148..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py +++ /dev/null @@ -1,214 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import re -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import dtype_to_bytes, root_mean_square, standardization - - -def get_bert_sizes_test(): - batch_sizes = [1, 8] - seq_lens = [64, 256] - hidden_sizes = [1, 2, 3, 4, 5, 7, 8, 9, 13, 32, 63, 64, 65, 127, 128, 129, 177, 256, 1023, 1024] - return product(batch_sizes, seq_lens, hidden_sizes) - - -def get_bert_sizes_profile(): - batch_sizes = [1, 8, 128, 256] - seq_lens = [64, 128, 256, 384] - hidden_sizes = [768, 1024] - return product(batch_sizes, seq_lens, hidden_sizes) - - -def dtype_to_funcs(dtype, simplified=False): - skip_layer_norm_prefix = "SimplifiedSkipLayerNorm" if simplified else "SkipLayerNorm" - type_map = { - "float16": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_half.*", x), dir(ke))), - "float32": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_float.*", x), dir(ke))), - } - return type_map[dtype] - - -def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon): - val = input_x + skip + bias - output = standardization(val, 2, epsilon) - output = output * gamma + beta - return output, val - - -def simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon): - val = input_x + skip + bias - rms = root_mean_square(val, 2, epsilon) - output = (val / rms) * gamma - return output, val - - -@ke.dispatchable(pattern_arg=4) -def run_skip_layer_norm( - batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, simplified=False, has_optional_output=False -): - np.random.seed(0) - input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - bias = np.random.rand(hidden_size).astype(dtype) - gamma = np.random.rand(hidden_size).astype(dtype) - beta = np.random.rand(hidden_size).astype(dtype) - # Because of rocm FMAs calculation issue with float16, epsilon should be larger when hidden_size is small - epsilon = 0.05 if hidden_size < 8 else 0.0005 - output_y = np.random.rand(batch_size, seq_len, 
hidden_size).astype(dtype) - output_optional = ( - np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - if has_optional_output - else np.empty((0), dtype=dtype) - ) - - input_d = ke.DeviceArray(input_x) - skip_d = ke.DeviceArray(skip) - bias_d = ke.DeviceArray(bias) - gamma_d = ke.DeviceArray(gamma) - beta_d = ke.DeviceArray(beta) - y_d = ke.DeviceArray(output_y) - optional_d = ke.DeviceArray(output_optional) - f = getattr(ke, func) - - my_op = f( - y_d, - optional_d, - input_d, - skip_d, - gamma_d, - beta_d, - bias_d, - epsilon, - hidden_size, - batch_size * seq_len * hidden_size, - ) - if my_op.IsSupported(): - my_op.Run() - - y_d.UpdateHostNumpyArray() - optional_d.UpdateHostNumpyArray() - - if simplified: - y_ref, y_optional = simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon) - else: - y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon) - np.testing.assert_almost_equal(y_ref, output_y, decimal=1) - if has_optional_output: - np.testing.assert_almost_equal(y_optional, output_optional, decimal=3) - - -dtypes = ["float32", "float16"] -simplified = [True, False] - - -@pytest.mark.parametrize("bert_sizes", get_bert_sizes_test()) -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("simplified", simplified) -def test_skip_layer_norm(bert_sizes, dtype, simplified): - for func in dtype_to_funcs(dtype, simplified): - run_skip_layer_norm(*bert_sizes, dtype, func, simplified) - - -@dataclass -class SkipLayerNormMetric(ke.BandwidthMetric): - batch_size: int - seq_len: int - hidden_size: int - - def report(self): - common = f"{self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} {self.name}" - if self.duration > 0: - return f"{self.duration:6.2f} us, {self.gbps:5.2f} GB/s " + common - return "not supported " + common - - -@ke.dispatchable(pattern_arg=4) -def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output): - np.random.seed(0) - input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - gamma = np.random.rand(hidden_size).astype(dtype) - beta = np.random.rand(hidden_size).astype(dtype) - bias = np.random.rand(hidden_size).astype(dtype) - epsilon = 0.0005 - output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - output_optional = ( - np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - if has_optional_output - else np.empty((0), dtype=dtype) - ) - - input_d = ke.DeviceArray(input_x) - skip_d = ke.DeviceArray(skip) - gamma_d = ke.DeviceArray(gamma) - beta_d = ke.DeviceArray(beta) - bias_d = ke.DeviceArray(bias) - y_d = ke.DeviceArray(output_y) - optional_d = ke.DeviceArray(output_optional) - f = getattr(ke, func) - - my_op = f( - y_d, - optional_d, - input_d, - skip_d, - gamma_d, - beta_d, - bias_d, - epsilon, - hidden_size, - batch_size * seq_len * hidden_size, - ) - - duration_ms = -1 - if my_op.IsSupported(): - duration_ms = my_op.Profile() - total_bytes = (input_x.size * 3 + bias.size * 3) * dtype_to_bytes(dtype) - - ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size)) - - -@ke.dispatchable -def profile_with_args(batch_size, seq_len, hidden_size, dtype, has_optional_output=False, simplified=False): - with ke.benchmark(): - for func in dtype_to_funcs(dtype, simplified): - profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, 
has_optional_output) - - -def profile(): - for dtype in dtypes: - for bert_size in get_bert_sizes_profile(): - profile_with_args(*bert_size, dtype) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("batch_size", type=int) - group.add_argument("seq_len", type=int) - group.add_argument("hidden_size", type=int) - group.add_argument("dtype", choices=dtypes) - group.add_argument("--has_optional_output", "-o", action="store_true") - group.add_argument("--simplified", "-s", action="store_true", default=False) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch( - args.batch_size, - args.seq_len, - args.hidden_size, - args.dtype, - args.has_optional_output, - args.simplified, - ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py deleted file mode 100644 index 3a7e4442108f5..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py +++ /dev/null @@ -1,141 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import re -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import dtype_to_bytes, dtype_to_suffix, softmax - - -def get_test_sizes(): - batch_count = [1, 8] - softmax_elements = [1, 2, 3, 4, 5, 7, 8, 9, 11, 16, 31, 32, 33, 64, 65, 127, 128, 1024, 1025, 2048, 4096] - is_log_softmax = [True, False] - return product(batch_count, softmax_elements, is_log_softmax) - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: re.match("Softmax.*_half.*", x), dir(ke))), - "float32": list(filter(lambda x: re.match("Softmax.*_float.*", x), dir(ke))), - } - return type_map[dtype] - - -def _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, func): - np.random.seed(0) - x = np.random.rand(batch_count, softmax_elements).astype(dtype) - y = np.random.rand(batch_count, softmax_elements).astype(dtype) - - x_d = ke.DeviceArray(x) - y_d = ke.DeviceArray(y) - y_ref = softmax(x, is_log_softmax=is_log_softmax) - - softmax_func = getattr(ke, func) - softmax_op = softmax_func( - y_d, x_d, softmax_elements, softmax_elements, softmax_elements, batch_count, is_log_softmax - ) - for impl in softmax_op.ListOps(): - if not softmax_op.SelectOp(impl): - continue - - softmax_op.Run() - y_d.UpdateHostNumpyArray() - - np.testing.assert_allclose(y_ref, y, rtol=1e-02, err_msg=func) - - -dtypes = ["float16", "float32"] - - -@pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes()) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_softmax(batch_count, softmax_elements, is_log_softmax, dtype): - for f in dtype_to_funcs(dtype): - _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, f) - - -@pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes()) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_ck_softmax(batch_count, softmax_elements, is_log_softmax, dtype): - ck_f_name = "CKSoftmax_" + dtype_to_suffix(dtype) - _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, ck_f_name) - - -@dataclass -class 
SoftmaxMetric(ke.BandwidthMetric): - batch_count: int - softmax_elements: int - is_log_softmax: bool - - def report(self): - common = f"{self.dtype} batch_count={self.batch_count:<4} softmax_elements={self.softmax_elements:<4} is_log_softmax={self.is_log_softmax:<4} {self.name}" - if self.duration > 0: - return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s " + common - return "not supported " + common - - -@ke.dispatchable(pattern_arg=4) -def profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func): - np.random.seed(0) - x = np.random.rand(batch_count, softmax_elements).astype(dtype) - y = np.random.rand(batch_count, softmax_elements).astype(dtype) - - x_d = ke.DeviceArray(x) - y_d = ke.DeviceArray(y) - - softmax_func = getattr(ke, func) - softmax_op = softmax_func( - y_d, x_d, softmax_elements, softmax_elements, softmax_elements, batch_count, is_log_softmax - ) - - for impl in softmax_op.ListOps(): - duration_ms = -1 - if softmax_op.SelectOp(impl): - duration_ms = softmax_op.Profile() - total_bytes = 2 * batch_count * softmax_elements * dtype_to_bytes(dtype) - - ke.report(SoftmaxMetric(impl, dtype, duration_ms, total_bytes, batch_count, softmax_elements, is_log_softmax)) - - -@ke.dispatchable -def profile_with_args(batch_count, softmax_elements, is_log_softmax, dtype): - with ke.benchmark(): - for func in dtype_to_funcs(dtype): - profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func) - # ck function - ck_f_name = "CKSoftmax_" + dtype_to_suffix(dtype) - profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, ck_f_name) - - -profile_size = [(1, 2048), (8, 2048), (65536, 4096)] - - -def profile(): - for dtype in dtypes: - for batch_count, softmax_elements in profile_size: - profile_with_args(batch_count, softmax_elements, False, dtype) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("batch_count", type=int) - group.add_argument("softmax_elements", type=int) - group.add_argument("is_log_softmax", type=int) - group.add_argument("dtype", choices=dtypes) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.batch_count, args.softmax_elements, args.is_log_softmax, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py deleted file mode 100644 index 4c7ed67f44c6b..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py +++ /dev/null @@ -1,256 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -import os -from dataclasses import dataclass -from itertools import product - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix - -max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", "64")) - - -def dtype_to_suffix(dtype): - return { - "float32": "float", - "float16": "half", - }[dtype] - - -@ke.dispatchable -def _test_strided_batched_gemm( - func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int, alpha=1.0, beta=0.0 -): - assert dtype in ["float32", "float16"] - - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - a = (np.random.rand(batch, *a_shape) + 0.5).astype(dtype).astype("float64") - b = (np.random.rand(batch, *b_shape) + 0.5).astype(dtype).astype("float64") - tmp_a = a.swapaxes(1, 2) if transa else a - tmp_b = b.swapaxes(1, 2) if transb else b - intermediate_c = matmul(tmp_a, tmp_b) - if alpha == 1.0 and beta == 0.0: # fast path - ref_c = intermediate_c - else: - ref_c = alpha * intermediate_c + beta * np.ones_like(intermediate_c) - - bounds = [get_gemm_bound(dtype, a[i], b[i], ref_c[i], transa, transb, a_b_positive=True) for i in range(batch)] - - a = a.astype(dtype) - b = b.astype(dtype) - - my_c = np.ones((batch, m, n), dtype=dtype) - dev_a = ke.DeviceArray(a) - dev_b = ke.DeviceArray(b) - dev_c = ke.DeviceArray(my_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - ldc = n - stride_a = m * k - stride_b = k * n - stride_c = m * n - my_gemm = func( - opa, opb, m, n, k, alpha, dev_a, lda, stride_a, dev_b, ldb, stride_b, beta, dev_c, ldc, stride_c, batch - ) - - failures = {} - print( - f"dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} batch={batch:<3} max bound: {max(bounds)}" - ) - - for impl in my_gemm.ListOps(): - if not my_gemm.SelectOp(impl): - continue - - # Restore C Array - my_c.fill(1.0) - dev_c.UpdateDeviceArray() - my_gemm.Run() - dev_c.UpdateHostNumpyArray() - - for i in range(batch): - try: - np.testing.assert_allclose(my_c[i], ref_c[i], rtol=bounds[i]) - except Exception as err: - header = "*" * 30 + impl + "*" * 30 - print(header, bounds[i]) - print(err) - print("*" * len(header)) - failures[impl] = str(err) - - if failures: - raise Exception(failures) - - -dtypes = ["float32", "float16"] -all_transabs = list(product([True, False], repeat=2)) - - -@pytest.mark.parametrize("batch", [1, max_batch_size]) -@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch): - wrapper_name = "RocBlasStridedBatchedGemm_" + dtype_to_suffix(dtype) - _test_strided_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) - - -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("batch", [1, max_batch_size]) -@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def 
test_ck_gemm_all_cases(dtype, transa, transb, m, n, k, batch): - wrapper_name = f"CKStridedBatchedGemm_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_strided_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) - - -# Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests -@pytest.mark.parametrize("batch", [1, max_batch_size]) -@pytest.mark.parametrize("m, n, k", get_gemm_bert_sizes(full=False)) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_gemm_tunable_bert_cases(dtype, transa, transb, m, n, k, batch): - wrapper_name = f"StridedBatchedGemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_strided_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) - - -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_rocblas_gemm_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = "RocBlasStridedBatchedGemm_" + dtype_to_suffix(dtype) - _test_strided_batched_gemm( - getattr(ke, wrapper_name), dtype, transa, transb, 128, 256, 768, 16, alpha=alpha, beta=beta - ) - - -@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_ck_gemm_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = f"CKStridedBatchedGemm_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_strided_batched_gemm( - getattr(ke, wrapper_name), dtype, transa, transb, 256, 128, 384, 8, alpha=alpha, beta=beta - ) - - -@pytest.mark.parametrize("alpha, beta", [(0.5, 0.5)]) -@pytest.mark.parametrize("transa, transb", all_transabs) -@pytest.mark.parametrize("dtype", dtypes) -def test_gemm_tunable_alpha_beta(dtype, transa, transb, alpha, beta): - wrapper_name = f"StridedBatchedGemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" - _test_strided_batched_gemm( - getattr(ke, wrapper_name), dtype, transa, transb, 128, 512, 384, 4, alpha=alpha, beta=beta - ) - - -@dataclass -class StridedBatchedGemmMetric(ke.ComputeMetric): - transa: bool - transb: bool - m: int - n: int - k: int - batch: int - - def report(self): - common = ( - f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} " - f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} {self.name}" - ) - if self.duration <= 0: - return "not supported " + common - - return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common - - -@ke.dispatchable(pattern_arg=0) -def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int): - a_shape = (k, m) if transa else (m, k) - b_shape = (n, k) if transb else (k, n) - - np.random.seed(0) - a = (np.random.rand(batch, *a_shape) * 2 - 1).astype(dtype) - b = (np.random.rand(batch, *b_shape) * 2 - 1).astype(dtype) - my_c = np.zeros((batch, m, n), dtype=dtype) - - dev_a = ke.DeviceArray(a) - dev_b = ke.DeviceArray(b) - dev_c = ke.DeviceArray(my_c) - - opa = ke.blas_op.T if transa else ke.blas_op.N - opb = ke.blas_op.T if transb else ke.blas_op.N - lda = a_shape[1] - ldb = b_shape[1] - ldc = n - stride_a = m * k - stride_b = k * n - stride_c = m * n - alpha = 1.0 - beta = 0.0 - my_gemm = f(opa, opb, m, n, k, 
alpha, dev_a, lda, stride_a, dev_b, ldb, stride_b, beta, dev_c, ldc, stride_c, batch) - for impl in my_gemm.ListOps(): - duration_ms = -1 - if my_gemm.SelectOp(impl): - duration_ms = my_gemm.Profile() - FLOPs = batch * m * k * n * 2 # noqa: N806 - ke.report(StridedBatchedGemmMetric(impl, dtype, duration_ms, FLOPs, transa, transb, m, n, k, batch)) - - -@ke.dispatchable -def profile_with_args(dtype, transa, transb, m, n, k, batch): - dtype_suffix = "_" + dtype_to_suffix(dtype) - transab_suffix = "_" + transab_to_suffix((transa, transb)) - fn_rocblas = getattr(ke, "RocBlasStridedBatchedGemm" + dtype_suffix) - fn_ck = getattr(ke, "CKStridedBatchedGemm" + dtype_suffix + transab_suffix) - fn_tunable = getattr(ke, "StridedBatchedGemmTunable" + dtype_suffix + transab_suffix) - if ke.is_hipblaslt_available(): - fn_hipblaslt = getattr(ke, "StridedBatchedGemmHipBlasLt" + dtype_suffix + transab_suffix) - with ke.benchmark(): - profile_gemm_func(fn_rocblas, dtype, transa, transb, m, n, k, batch) - profile_gemm_func(fn_ck, dtype, transa, transb, m, n, k, batch) - profile_gemm_func(fn_tunable, dtype, transa, transb, m, n, k, batch) - if ke.is_hipblaslt_available(): - profile_gemm_func(fn_hipblaslt, dtype, transa, transb, m, n, k, batch) - print() - - -def profile(): - for dtype in dtypes: - for m, n, k in get_gemm_bert_sizes(full=False): - for batch in [1, 32, 64]: - profile_with_args(dtype, False, False, m, n, k, batch) - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("dtype", choices=dtypes) - group.add_argument("transa", choices="NT") - group.add_argument("transb", choices="NT") - group.add_argument("m", type=int) - group.add_argument("n", type=int) - group.add_argument("k", type=int) - group.add_argument("batch", type=int) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py deleted file mode 100644 index cdbae640b05d5..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py +++ /dev/null @@ -1,150 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -import os -from itertools import product - -import numpy as np -import scipy.special - - -def dtype_to_bytes(dtype): - type_map = { - "float8_e4m3fn": 1, - "float8_e4m3fnuz": 1, - "float8_e5m2": 1, - "float8_e5m2fnuz": 1, - "float16": 2, - "float32": 4, - "float64": 8, - } - return type_map[dtype] - - -def transab_to_suffix(transab): - return { - (True, True): "TT", - (True, False): "TN", - (False, True): "NT", - (False, False): "NN", - }[tuple(transab)] - - -def dtype_to_suffix(dtype): - return { - "float32": "float", - "float16": "half", - "float8_e4m3fn": "fp8e4m3fn", - "float8_e4m3fnuz": "fp8e4m3fnuz", - }[dtype] - - -def get_gemm_bound( - dtype: str, - a: np.ndarray, - b: np.ndarray, - c: np.ndarray, - transa: bool, - transb: bool, - a_b_positive=False, # if both a and b are positive matrix, we can skip coeff computation -): - k = b.shape[1] if transb else b.shape[0] - # The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point - # number that represents 1 + n is greater than 1. - machine_eps = 2.0 ** -(24 if dtype == "float32" else 11) - - # The following implements error bound 5.7 in paper I. C. Ipsen and H. Zhou, “Probabilistic error analysis for - # Inner Products,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 4, pp. 1726-1741, 2020. - # NOTE: the bound is not tight for float16 when k is large - if a_b_positive: - coeff = 1.0 - else: - absa_mul_absb = np.abs(a.T if transa else a) @ np.abs(b.T if transb else b) - coeff = np.max(absa_mul_absb / np.abs(c)) - gamma_2k = (1.0 + machine_eps) ** (2 * k) - 1.0 - bound_5_7 = coeff * np.sqrt(np.log(2 / 1e-10) * machine_eps * gamma_2k / 2) - bound = bound_5_7 - - return bound - - -def get_gemm_bert_sizes(full=True): - bert_base_sizes = [ - # m, n, k - (384, 768, 768), - (384, 768, 768 * 3), - (384, 768, 768 * 4), - (384, 768 * 4, 768), - (384, 1024, 1024), - (384, 1024, 1024 * 3), - (384, 1024, 1024 * 4), - (384, 1024 * 4, 1024), - ] - - # we then multiply m with the batch size - if full: - batch_sizes = [1, 64] - else: - batch_sizes = [1] - bert_sizes = [] - for bsz in batch_sizes: - bert_sizes.extend([(m * bsz, n, k) for m, n, k in bert_base_sizes]) - return bert_sizes - - -def get_gemm_basic_sizes(full=True): - if full: - return list(product([1, 3, 4, 16, 127, 128, 129, 133, 1024], repeat=3)) - - # ck has various impls to be tested, use the full basic cases will result too many cases to test. - # So we use a reduced combination here. 
- return list(product([1, 4, 127, 133], [3, 16, 128], [3, 129, 1024])) - - -def softmax(x, *, is_log_softmax=False, axis=-1): - x = x - np.max(x, axis=axis, keepdims=1) - if is_log_softmax: - return x - np.log(np.sum(np.exp(x), axis=axis, keepdims=1)) - return (np.exp(x)) / np.sum(np.exp(x), axis=axis, keepdims=1) - - -def _matmul(a, b): - if os.getenv("KERNEL_EXPLORER_TEST_USE_CUPY", "0") == "1": - import cupy as cp - - return (cp.asarray(a) @ cp.asarray(b)).get() - else: - return a @ b - - -def matmul(a, b, transa=False, transb=False): - return _matmul(a.T if transa else a, b.T if transb else b) - - -def fast_gelu(x, bias): - x = x + bias - y = 0.5 * x * (1 + np.tanh(0.797885 * x + 0.035677 * x * x * x)) - return y - - -def gelu(x, bias): - x = x + bias - return 0.5 * x * (1 + scipy.special.erf(x / np.sqrt(2))) - - -def relu(x, bias): - x = x + bias - return np.max(x, 0, keepdims=True) - - -def root_mean_square(x, axis, epsilon): - rms = np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + epsilon) - return rms - - -def standardization(x, axis, epsilon): - mean = np.mean(x, axis=axis, keepdims=True) - variance = np.var(x, axis=axis, keepdims=True) - return (x - mean) / np.sqrt(variance + epsilon) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu deleted file mode 100644 index eeaa2a6d47e77..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// This file serve as a simple example for adding a tunable op to onnxruntime. - -#if USE_CUDA -#include -#include -#elif USE_ROCM -#include -#endif -#include - -#include - -#if USE_CUDA -#include "core/providers/cuda/tunable/cuda_tunable.h" -#elif USE_ROCM -#include "core/providers/rocm/tunable/rocm_tunable.h" -#endif -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" -#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" - -namespace py = pybind11; - -namespace onnxruntime { - -// ##################################################################################################################### -// In practice, VectorAddParam, VectorAddOp and VectorAddTunableOp should be tightly integrated to onnxruntime. -// We place them here purely for demo purpose. -// ##################################################################################################################### - -// Extend the OpParams so that all specializations have the same parameter passing interface -template -struct VectorAddParams : -#if USE_CUDA - cuda::tunable::OpParams -#elif USE_ROCM - rocm::tunable::OpParams -#endif -{ - std::string Signature() const override { return std::to_string(n); } - - T* x; - T* y; - T* z; - int n; -}; - -// Wrap the kernel function, so that we have a unified launch interface. If the kernel has state, the state can also -// be managed at this level via a functor -template -Status VectorAddOp(const VectorAddParams* params) { - return LaunchVectorAdd( - params->StreamHandle(), - params->x, - params->y, - params->z, - params->n); -} - -#define ADD_OP(threads_per_block) \ - this->RegisterOp(VectorAddOp); \ - this->RegisterOp(VectorAddOp); \ - this->RegisterOp(VectorAddOp); \ - this->RegisterOp(VectorAddOp); - -// A Tunable VectorAddOp is a collection of non-tunable VectorAddOps implementations that have variable performance -// characteristics. 
Those implementations may be put into a C++ container for tuner to select. -template -class VectorAddTunableOp : -#if USE_CUDA - public cuda::tunable::TunableOp> -#elif USE_ROCM - public rocm::tunable::TunableOp> -#endif -{ - public: - VectorAddTunableOp() { - ADD_OP(64); - ADD_OP(128); - ADD_OP(192); - ADD_OP(256); - ADD_OP(320); - ADD_OP(384); - ADD_OP(448); - ADD_OP(512); - } -}; - -#undef ADD_OP - -// ##################################################################################################################### -// Following code just wraps our kernel implementation and expose them as python interface. This is the code that -// should be in the kernel_explorer directory. -// ##################################################################################################################### - -template -class VectorAdd : public IKernelExplorer { - public: - VectorAdd(DeviceArray& x, DeviceArray& y, DeviceArray& z, int n) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.x = static_cast(x.ptr()); - params_.y = static_cast(y.ptr()); - params_.z = static_cast(z.ptr()); - params_.n = n; - } - - void Run() override { - ORT_THROW_IF_ERROR((VectorAddOp(¶ms_))); - } - - private: - // A VectorAddOp is a callable that can process const VectorAddParams* - using ParamsT = VectorAddParams; - ParamsT params_{}; -}; - -template -class VectorAddTunable : public IKernelExplorer { - public: - VectorAddTunable(DeviceArray& x, DeviceArray& y, DeviceArray& z, int n) { - params_.tuning_ctx = TuningContext(); - params_.stream = Stream(); - params_.x = static_cast(x.ptr()); - params_.y = static_cast(y.ptr()); - params_.z = static_cast(z.ptr()); - params_.n = n; - - params_.TuningContext()->EnableTunableOpAndTuning(); - } - - void Run() override { - ORT_THROW_IF_ERROR(impl_(¶ms_)); - } - - private: - using ParamsT = VectorAddParams; - ParamsT params_; - - // tunable is stateful, store it as an instance - VectorAddTunableOp impl_; -}; - -#define REGISTER_OP(name, type, threads_per_block, vec_size) \ - py::class_>(m, #name"_"#type"_"#threads_per_block"_"#vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); - -#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ - REGISTER_OP(name, type, threads_per_block, 1) \ - REGISTER_OP(name, type, threads_per_block, 2) \ - REGISTER_OP(name, type, threads_per_block, 4) \ - REGISTER_OP(name, type, threads_per_block, 8) - -#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(name, type) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512) - -#define REGISTER_TUNABLE_OP(type) \ - py::class_>(m, "VectorAdd_" #type "_Tunable") \ - .def(py::init()) \ - .def("SetRepeats", &VectorAddTunable::SetRepeats) \ - .def("Profile", &VectorAddTunable::Profile) \ - .def("Run", &VectorAddTunable::Run); - -KE_REGISTER(m) { - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, half); - REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(VectorAdd, float); - - REGISTER_TUNABLE_OP(half); - REGISTER_TUNABLE_OP(float) -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_kernel.cuh 
b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_kernel.cuh deleted file mode 100644 index a2aceaa9a15c5..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_kernel.cuh +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#if USE_CUDA -#include -#include "core/providers/cuda/cu_inc/common.cuh" -#include "core/providers/cuda/tunable/util.h" -#elif USE_ROCM -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/tunable/util.h" -#endif - -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -#if USE_CUDA -using onnxruntime::cuda::aligned_vector; -using onnxruntime::cuda::CeilDiv; -using StreamT = cudaStream_t; -#elif USE_ROCM -using onnxruntime::rocm::aligned_vector; -using onnxruntime::rocm::CeilDiv; -using StreamT = hipStream_t; -#endif - -namespace onnxruntime { - -template -__global__ void VectorAddKernel(const T* __restrict__ x, - const T* __restrict__ y, - T* __restrict__ z, int n) { - int i = blockDim.x * blockIdx.x + threadIdx.x; - using LoadT = aligned_vector; - - if (VecSize * i + VecSize - 1 < n) { - T x_vec[VecSize]; - LoadT* x_load = reinterpret_cast(&x_vec); - *x_load = *reinterpret_cast(&x[VecSize * i]); - - T y_vec[VecSize]; - LoadT* y_load = reinterpret_cast(&y_vec); - *y_load = *reinterpret_cast(&y[VecSize * i]); - - T z_vec[VecSize]; - -#pragma unroll - for (int j = 0; j < VecSize; j++) { - z_vec[j] = x_vec[j] + y_vec[j]; - } - - *(reinterpret_cast(&z[VecSize * i])) = *reinterpret_cast(&z_vec[0]); - } - - if (i == 0) { - int tail_size = n % VecSize; - for (int j = n - 1; j >= n - tail_size; j--) { - z[j] = x[j] + y[j]; - } - } -} - -template -Status LaunchVectorAdd(StreamT stream, const T* x, const T* y, T* z, int n) { - VectorAddKernel - <<>>(x, y, z, n); -#if USE_CUDA - return CUDA_CALL(cudaGetLastError()); -#elif USE_ROCM - return HIP_CALL(hipGetLastError()); -#endif -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py deleted file mode 100644 index 8edf55f68c11f..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -from dataclasses import dataclass - -import kernel_explorer as ke -import numpy as np -import pytest -from utils import dtype_to_bytes - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: "VectorAdd_half" in x, dir(ke))), - "float32": list(filter(lambda x: "VectorAdd_float" in x, dir(ke))), - } - return type_map[dtype] - - -def run_vector_add(size, dtype, func): - np.random.seed(0) - x = np.random.rand(size).astype(dtype) - y = np.random.rand(size).astype(dtype) - z = np.random.rand(size).astype(dtype) - - x_d = ke.DeviceArray(x) - y_d = ke.DeviceArray(y) - z_d = ke.DeviceArray(z) - f = getattr(ke, func) - my_op = f(x_d, y_d, z_d, size) - my_op.Run() - z_d.UpdateHostNumpyArray() - - z_ref = x + y - np.testing.assert_allclose(z_ref, z) - - -dtypes = ["float16", "float32"] - - -@pytest.mark.parametrize("size", [1, 3, 4, 16, 124, 125, 126, 127, 128, 129, 130, 131, 132, 1024]) -@pytest.mark.parametrize("dtype", dtypes) -@ke.dispatchable -def test_vector_add(size, dtype): - for dtype in dtypes: - for f in dtype_to_funcs(dtype): - run_vector_add(size, dtype, f) - - -@dataclass -class VectorAddMetric(ke.BandwidthMetric): - size: int - - def report(self): - return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} size={self.size:<4} {self.name}" - - -@ke.dispatchable(pattern_arg=2) -def profile_vector_add_func(size, dtype, func): - np.random.seed(0) - x = np.random.rand(size).astype(dtype) - y = np.random.rand(size).astype(dtype) - z = np.random.rand(size).astype(dtype) - - x_d = ke.DeviceArray(x) - y_d = ke.DeviceArray(y) - z_d = ke.DeviceArray(z) - f = getattr(ke, func) - my_op = f(x_d, y_d, z_d, size) - duration_ms = my_op.Profile() - total_bytes = size * 3 * (dtype_to_bytes(dtype)) - - ke.report(VectorAddMetric(func, dtype, duration_ms, total_bytes, size)) - - -@ke.dispatchable -def profile_with_args(size, dtype): - with ke.benchmark(): - for func in dtype_to_funcs(dtype): - profile_vector_add_func(size, dtype, func) - - -def profile(): - sizes = [10000, 100000, 1000000, 10000000] - for dt in dtypes: - for s in sizes: - profile_with_args(s, dt) - print() - - -if __name__ == "__main__": - parser = ke.get_argument_parser() - group = parser.add_argument_group() - group.add_argument("size", type=int) - group.add_argument("dtype", choices=dtypes) - - if not ke.has_args(): - profile() - else: - args = parser.parse_args() - args.dispatch(args.size, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/version_script.lds b/onnxruntime/python/tools/kernel_explorer/version_script.lds deleted file mode 100644 index a54293c2a09f6..0000000000000 --- a/onnxruntime/python/tools/kernel_explorer/version_script.lds +++ /dev/null @@ -1,6 +0,0 @@ -# Export everything for building with kernel explorer, -# so that we can reuse all those utilities functions -VERS_0.0 { - global: - **; -}; diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index b6b62dc3bb3a1..ac696ff3788aa 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -158,6 +158,14 @@ def parse_arguments(argv=None): ) conversion_args.set_defaults(no_beam_search_op=False) + conversion_args.add_argument( + "--use_decoder_masked_mha", + required=False, + action="store_true", + help="Use 
DecoderMaskedMultiHeadAttention kernel for improved performance. This is currently an experimental feature.", + ) + conversion_args.set_defaults(use_decoder_masked_mha=False) + ############################################################# # Optional inputs for Whisper # (listed below in the order that WhisperBeamSearch expects) @@ -305,8 +313,13 @@ def parse_arguments(argv=None): quant_args.set_defaults(quantize_reduce_range=False) args = parser.parse_args(argv) + + # Collect cross QKs if either flag is enabled args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk + # FP32 CPU can be supported here once the DMMHA CPU kernel bugs are fixed + args.use_decoder_masked_mha = args.use_decoder_masked_mha and args.provider == "cuda" + return args @@ -323,6 +336,7 @@ def export_onnx_models( use_forced_decoder_ids: bool = False, merge_encoder_and_decoder_init: bool = True, no_beam_search_op: bool = False, + use_decoder_masked_mha: bool = False, output_qk: bool = False, overwrite: bool = False, use_int32_inputs: bool = True, @@ -402,6 +416,7 @@ def export_onnx_models( provider=provider, is_decoder=(name == "decoder"), no_beam_search_op=no_beam_search_op, + use_decoder_masked_mha=use_decoder_masked_mha, output_qk=output_qk, ) # Remove old ONNX model and old data file @@ -474,6 +489,7 @@ def main(argv=None): args.use_forced_decoder_ids, not args.separate_encoder_and_decoder_init, args.no_beam_search_op, + args.use_decoder_masked_mha, args.output_cross_qk, args.overwrite, not args.use_int64_inputs, @@ -541,6 +557,7 @@ def main(argv=None): args.model_name_or_path, args.provider, args.separate_encoder_and_decoder_init, + args.use_decoder_masked_mha, args.output_cross_qk, next(iter(filter(lambda path: "encoder" in path, output_paths))), next(iter(filter(lambda path: "decoder" in path, output_paths))), diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 08118ccb551eb..a72185a2d9213 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -73,6 +73,7 @@ def save_processing( model_name_or_path: str, provider: str, separate_encoder_and_decoder_init: bool, + use_decoder_masked_mha: bool, output_qk: bool, encoder_path: str, decoder_path: str, @@ -596,7 +597,7 @@ def save_processing( "no_repeat_ngram_size": 0, "num_beams": 1, "num_return_sequences": 1, - "past_present_share_buffer": provider == "cuda", + "past_present_share_buffer": use_decoder_masked_mha, "repetition_penalty": 1.0, "temperature": 1.0, "top_k": 1, @@ -604,16 +605,13 @@ def save_processing( }, } - # Requirements for the DMMHA kernel (which is currently - # enabled for CUDA only): + # Requirements for the DMMHA kernel: # - Buffer sharing = true # - New input: past_sequence_length # - New input: cache_indirection - # Otherwise, buffer sharing should be false and the new inputs - # should not be added for beam search to work in ORT GenAI. - - if provider == "cuda": - # Add inputs for DMMHA kernel + # Otherwise, buffer sharing should be false and the new inputs should not be added + # for beam search to work in ORT GenAI. 
+ if use_decoder_masked_mha: genai_config["model"]["decoder"]["inputs"].update( { "past_sequence_length": "past_sequence_length", @@ -771,6 +769,7 @@ def optimize_onnx( provider: str = "cpu", is_decoder: bool = False, no_beam_search_op: bool = False, + use_decoder_masked_mha: bool = False, output_qk: bool = False, ): """Optimize ONNX model with an option to convert it to use mixed precision.""" @@ -794,7 +793,7 @@ def optimize_onnx( # Add `past_sequence_length`, `cache_indirection`, and `output_qk` to `MultiHeadAttention` ops if is_decoder and no_beam_search_op: - if provider == "cuda": # FP32 CPU can be supported here once the DMMHA CPU kernel bugs are fixed + if use_decoder_masked_mha: # FP16 CUDA, FP32 CUDA, and FP32 CPU use the `DecoderMaskedMultiHeadAttention` kernel # via `MultiHeadAttention`, which requires the `past_sequence_length` and # `cache_indirection` inputs diff --git a/onnxruntime/python/tools/transformers/requirements.txt b/onnxruntime/python/tools/transformers/requirements.txt index ce1380a757729..005816e43813d 100644 --- a/onnxruntime/python/tools/transformers/requirements.txt +++ b/onnxruntime/python/tools/transformers/requirements.txt @@ -11,4 +11,4 @@ sentencepiece pillow # please follow https://pytorch.org/ to install PyTorch for your OS -torch >= 1.13.1 +torch >= 2.6.0 diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc new file mode 100644 index 0000000000000..d9418e3e3156d --- /dev/null +++ b/onnxruntime/test/autoep/library/ep.cc @@ -0,0 +1,502 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ep_factory.h" + +/// +/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. +/// +struct MulKernel { + MulKernel(const OrtApi& ort_api, const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, std::string input1_name) + : ort_api(ort_api), + logger(logger), + float_initializers(float_initializers), + input0_name(input0_name), + input1_name(input1_name) {} + + const FloatInitializer* TryGetSavedInitializer(const std::string& name) const { + auto iter = float_initializers.find(name); + return iter != float_initializers.end() ? 
&iter->second : nullptr; + } + + OrtStatus* GetInputDataAndShape(OrtKernelContext* kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + const OrtValue* input = nullptr; + RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, index, &input)); + + OrtTensorTypeAndShapeInfo* type_shape = nullptr; + DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); + + RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input, &type_shape)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); + + size_t num_elems = 0; + RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); + + size_t num_dims = 0; + RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); + + shape.resize(num_dims, 0); + RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, shape.data(), shape.size())); + + const void* raw_data = nullptr; + RETURN_IF_ERROR(ort_api.GetTensorData(input, &raw_data)); + + const float* float_data = static_cast(raw_data); + data = gsl::span(float_data, num_elems); + return nullptr; + } + + OrtStatus* Compute(OrtKernelContext* kernel_context) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = 0; + RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(kernel_context, &num_inputs)); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. + RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); + RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 1, input1, shape1)); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". + + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input1, shape1)); + input0 = gsl::span(const_input0->data); + shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); + input1 = gsl::span(const_input1->data); + shape1 = const_input1->shape; + } + } else { + // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + + input0 = gsl::span(const_input0->data); + input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; + shape1 = const_input1->shape; + } + + RETURN_IF(shape0 != shape1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. 
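+    // Shapes match (no broadcasting), so the single output uses the same shape as the inputs;
+    // the output tensor is allocated through the kernel context and filled with the elementwise product below.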
+ + size_t num_outputs = 0; + RETURN_IF_ERROR(ort_api.KernelContext_GetOutputCount(kernel_context, &num_outputs)); + RETURN_IF(num_outputs != 1, ort_api, "Expected 1 output for MulKernel"); + + OrtValue* output = nullptr; + float* output_data = nullptr; + RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, shape0.data(), shape0.size(), &output)); + RETURN_IF_ERROR(ort_api.GetTensorMutableData(output, reinterpret_cast(&output_data))); + + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; + } + + return nullptr; + } + + const OrtApi& ort_api; + const OrtLogger& logger; + const std::unordered_map& float_initializers; + std::string input0_name; + std::string input1_name; +}; + +/// +/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. +/// +struct ExampleNodeComputeInfo : OrtNodeComputeInfo { + explicit ExampleNodeComputeInfo(ExampleEp& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + ExampleEp& ep; +}; + +ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger) + : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized + ApiPtrs{static_cast(factory)}, + factory_{factory}, + name_{name}, + config_{config}, + logger_{logger} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + + // Initialize the execution provider's function table + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + Compile = CompileImpl; + ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + + auto status = ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + ("ExampleEp has been created with name " + name_).c_str(), + ORT_FILE, __LINE__, __FUNCTION__); + // ignore status for now + (void)status; +} + +ExampleEp::~ExampleEp() = default; + +/*static*/ +const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->name_.c_str(); +} + +OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { + OrtArrayOfConstObjects* initializers = nullptr; + DeferOrtRelease release_initializers(&initializers, ort_api.ReleaseArrayOfConstObjects); + size_t num_initializers = 0; + + RETURN_IF_ERROR(ort_api.Graph_GetInitializers(graph, &initializers)); + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(initializers, &num_initializers)); + + for (size_t i = 0; i < num_initializers; ++i) { + const OrtValueInfo* initializer = nullptr; + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(initializers, i, + reinterpret_cast(&initializer))); + + bool is_constant = false; + RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(initializer, &is_constant)); + + if (is_constant) { + const char* name = nullptr; + const OrtValue* value = nullptr; + OrtTensorTypeAndShapeInfo* type_shape = nullptr; + DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); + size_t num_elems = 0; + + RETURN_IF_ERROR(ort_api.GetValueInfoName(initializer, &name)); + RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer, &value)); + 
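+        // Read the initializer's shape and float data and keep a host-side copy that stays valid after ORT
+        // releases the dropped initializer; a real EP would typically upload the weights to device memory here instead.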
RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(value, &type_shape)); + RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 initializers"); + + size_t num_dims = 0; + RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); + + std::vector dims(num_dims, 0); + RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, dims.data(), dims.size())); + + const float* data = nullptr; + RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(value), (void**)&data)); + + FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; + float_initializers_.emplace(name, std::move(ep_initializer)); + } + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) { + ExampleEp* ep = static_cast(this_ptr); + + OrtArrayOfConstObjects* nodes_array = nullptr; + DeferOrtRelease release_nodes_array(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); + + size_t num_nodes = 0; + + RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, &nodes_array)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + + if (num_nodes == 0) { + return nullptr; // No nodes to process + } + + const void* const* nodes_data = nullptr; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(nodes_array, &nodes_data)); + auto nodes_span = gsl::span(reinterpret_cast(nodes_data), num_nodes); + + std::vector supported_nodes; + + for (const OrtNode* node : nodes_span) { + const char* op_type = nullptr; + RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); + + if (std::strncmp(op_type, "Mul", 4) == 0) { + // Check that Mul has inputs/output of type float + OrtArrayOfConstObjects* inputs_array = nullptr; + OrtArrayOfConstObjects* outputs_array = nullptr; + DeferOrtRelease release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects); + DeferOrtRelease release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects); + + RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, &inputs_array)); + RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, &outputs_array)); + + size_t num_inputs = 0; + size_t num_outputs = 0; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_inputs)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_outputs)); + RETURN_IF(num_inputs != 2 || num_outputs != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); + + const void* const* inputs_data = nullptr; + const void* const* outputs_data = nullptr; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(inputs_array, &inputs_data)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(outputs_array, &outputs_data)); + + std::array is_float = {false, false, false}; + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[0]), is_float[0])); + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[1]), is_float[1])); + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(outputs_data[0]), is_float[2])); + if (!is_float[0] || !is_float[1] || !is_float[2]) { + continue; // Input or output is not of type float + } + + supported_nodes.push_back(node); // Only support a single Mul for now. 
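+      // A fuller implementation would keep scanning and collect every supported node instead of stopping at the first Mul.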
+ break; + } + } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), + supported_nodes.size(), &node_fusion_options)); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) { + ExampleEp* ep = static_cast(this_ptr); + const OrtApi& ort_api = ep->ort_api; + + if (count != 1) { + return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single graph"); + } + + // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. + // So, this EP saves constant initializers so that they're available during inference, but an actual EP + // implementation could transfer the weights to device memory. + ep->SaveConstantInitializers(graphs[0]); + + OrtArrayOfConstObjects* nodes_array = nullptr; + DeferOrtRelease release_nodes(&nodes_array, ort_api.ReleaseArrayOfConstObjects); + size_t num_nodes = 0; + + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graphs[0], &nodes_array)); + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + + if (num_nodes != 1) { + return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); + } + + const OrtNode* node_to_compile = nullptr; + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(nodes_array, 0, + reinterpret_cast(&node_to_compile))); + + const char* node_op_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node_to_compile, &node_op_type)); + + if (std::strncmp(node_op_type, "Mul", 4) != 0) { + return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); + } + + // Now we know we're compiling a single Mul node. Create a computation kernel. + OrtArrayOfConstObjects* inputs = nullptr; + DeferOrtRelease release_inputs(&inputs, ort_api.ReleaseArrayOfConstObjects); + + RETURN_IF_ERROR(ort_api.Node_GetInputs(node_to_compile, &inputs)); + const OrtValueInfo* input0 = nullptr; + const OrtValueInfo* input1 = nullptr; + + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, 0, reinterpret_cast(&input0))); + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, 1, reinterpret_cast(&input1))); + + const char* input0_name = nullptr; + const char* input1_name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(input0, &input0_name)); + RETURN_IF_ERROR(ort_api.GetValueInfoName(input1, &input1_name)); + + // Associate the name of the fused node with our MulKernel. 
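+  // The fused node's name is the key that ExampleNodeComputeInfo::CreateStateImpl later uses to look this kernel up.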
+ const char* fused_node_name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); + + ep->kernels_.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, + ep->float_initializers_, + input0_name, input1_name)); + // Update the OrtNodeComputeInfo associated with the graph. + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); + + // Create EpContext nodes for the fused nodes we compiled. + if (ep->config_.enable_ep_context) { + assert(ep_context_nodes != nullptr); + RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + gsl::span(ep_context_nodes, count))); + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) { + (void)this_ptr; + for (size_t i = 0; i < num_node_compute_infos; i++) { + delete node_compute_infos[i]; + } +} + +// Creates EPContext nodes from the given fused nodes. +// This is an example implementation that can be used to generate an EPContext model. However, this example EP +// cannot currently run the EPContext model. +OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes, + /*out*/ gsl::span ep_context_nodes) { + assert(fused_nodes.size() == ep_context_nodes.size()); + + // Helper to collect input or output names from an array of OrtValueInfo instances. + auto collect_input_output_names = [&](const OrtArrayOfConstObjects& value_infos, + std::vector& result) -> OrtStatus* { + size_t num_values = 0; + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(&value_infos, &num_values)); + + std::vector value_names(num_values, nullptr); + + for (size_t i = 0; i < num_values; i++) { + const void* value_info = nullptr; // Is a const OrtValueInfo* + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(&value_infos, i, &value_info)); + RETURN_IF_ERROR(ort_api.GetValueInfoName(static_cast(value_info), &value_names[i])); + } + + result = std::move(value_names); + return nullptr; + }; + + // Create an "EPContext" node for every fused node. + for (size_t i = 0; i < fused_nodes.size(); ++i) { + const OrtNode* fused_node = fused_nodes[i]; + const char* fused_node_name = nullptr; + + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &fused_node_name)); + + OrtArrayOfConstObjects* fused_node_inputs = nullptr; + OrtArrayOfConstObjects* fused_node_outputs = nullptr; + DeferOrtRelease defer_release0(&fused_node_inputs, ort_api.ReleaseArrayOfConstObjects); + DeferOrtRelease defer_release1(&fused_node_outputs, ort_api.ReleaseArrayOfConstObjects); + + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, &fused_node_inputs)); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, &fused_node_outputs)); + + std::vector input_names; + std::vector output_names; + + RETURN_IF_ERROR(collect_input_output_names(*fused_node_inputs, /*out*/ input_names)); + RETURN_IF_ERROR(collect_input_output_names(*fused_node_outputs, /*out*/ output_names)); + + int64_t is_main_context = (i == 0); + int64_t embed_mode = 1; + + // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. 
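+    // With embed_mode = 1 the "ep_cache_context" attribute embeds the compiled blob directly; this example only
+    // stores the placeholder string "binary_data" since it cannot run the EPContext model it produces.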
+ std::array attributes = {}; + DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); + + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", "binary_data", 1, ORT_OP_ATTR_STRING, &attributes[0])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT, &attributes[1])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[2])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING, &attributes[3])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, 1, ORT_OP_ATTR_STRING, &attributes[4])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[5])); + + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, + input_names.data(), input_names.size(), + output_names.data(), output_names.size(), + attributes.data(), attributes.size(), + &ep_context_nodes[i])); + } + + return nullptr; +} +// +// Implementation of ExampleNodeComputeInfo +// +ExampleNodeComputeInfo::ExampleNodeComputeInfo(ExampleEp& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + ExampleEp& ep = node_compute_info->ep; + + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); + auto kernel_it = ep.Kernels().find(fused_node_name); + if (kernel_it == ep.Kernels().end()) { + std::string message = "Unable to get kernel for fused node with name " + fused_node_name; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + MulKernel& kernel = *kernel_it->second; + *compute_state = &kernel; + return nullptr; +} + +OrtStatus* ExampleNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + (void)this_ptr; + MulKernel& kernel = *reinterpret_cast(compute_state); + return kernel.Compute(kernel_context); +} + +void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + MulKernel& kernel = *reinterpret_cast(compute_state); + (void)kernel; + // Do nothing for this example. +} diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h new file mode 100644 index 0000000000000..b8c63f39438ba --- /dev/null +++ b/onnxruntime/test/autoep/library/ep.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "example_plugin_ep_utils.h" + +class ExampleEpFactory; +struct MulKernel; + +/// +/// Example EP that can compile a single Mul operator. 
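+/// GetCapability requests fusion of the Mul node, Compile builds a MulKernel for the fused node, and
+/// EPContext nodes can optionally be emitted.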
+/// +class ExampleEp : public OrtEp, public ApiPtrs { + public: + struct Config { + bool enable_ep_context = false; + // Other EP configs (typically extracted from OrtSessionOptions or OrtHardwareDevice(s)) + }; + + ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger); + + ~ExampleEp(); + + std::unordered_map>& Kernels() { + return kernels_; + } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info); + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes); + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos); + + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, + /*out*/ gsl::span ep_context_nodes); + + OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph); + + ExampleEpFactory& factory_; + std::string name_; + Config config_{}; + const OrtLogger& logger_; + std::unordered_map> kernels_; + std::unordered_map float_initializers_; +}; diff --git a/onnxruntime/test/autoep/library/ep_allocator.h b/onnxruntime/test/autoep/library/ep_allocator.h new file mode 100644 index 0000000000000..5e6f81fd0aa9e --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_allocator.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "example_plugin_ep_utils.h" + +struct CustomAllocator : OrtAllocator { + CustomAllocator(const OrtMemoryInfo* mem_info) : memory_info{mem_info} { + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* /*this_*/, size_t size) { + // CustomAllocator& impl = *static_cast(this_); + return malloc(size); + } + + /// Free a block of memory previously allocated with OrtAllocator::Alloc + static void ORT_API_CALL FreeImpl(struct OrtAllocator* /*this_*/, void* p) { + return free(p); + } + + /// Return a pointer to an ::OrtMemoryInfo that describes this allocator + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const CustomAllocator& impl = *static_cast(this_); + return impl.memory_info; + } + + private: + const OrtMemoryInfo* memory_info; +}; diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.cc b/onnxruntime/test/autoep/library/ep_data_transfer.cc new file mode 100644 index 0000000000000..48f97fe88ec44 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_data_transfer.cc @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_data_transfer.h" + +#include +#include + +/*static*/ +bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(void* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + static constexpr uint32_t VendorId = 0xBE57; // Example vendor ID for demonstration purposes. 
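+  // A device-specific implementation would compare MemoryDevice_GetVendorId(...) against VendorId here;
+  // this example only distinguishes its own memory device from CPU.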
+ + auto& impl = *static_cast(this_ptr); + bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info); + bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info); + + if (src_is_our_device && dst_is_our_device) { + return true; + } + + // implementation should check if the copy is possible, which may require checking the device type, the memory type + // and the vendor and device IDs as needed. + OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); + OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); + // OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_memory_device); + // OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_memory_device); + // uint32_t src_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); + // uint32_t dst_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); + // uint32_t src_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); + // uint32_t dst_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); + + if (src_is_our_device) { + // check device type and vendor to see if compatible + return (dst_device_type == OrtMemoryInfoDeviceType_CPU); + } + + if (dst_is_our_device) { + // check device type and vendor to see if compatible + return (src_device_type == OrtMemoryInfoDeviceType_CPU); + } + + return false; +} + +// function to copy one or more tensors. +// implementation can optionally use async copy if a stream is available for the input. +/*static*/ +OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, + const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, + OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + + auto src_tensors = gsl::make_span(src_tensors_ptr, num_tensors); + auto dst_tensors = gsl::make_span(dst_tensors_ptr, num_tensors); + auto streams = gsl::make_span(streams_ptr, num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + // NOTE: Stream support will be a separate PR. 
ignore teh streams_ptr values for now + + const OrtMemoryDevice* src_device = nullptr; + const OrtMemoryDevice* dst_device = nullptr; + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensors[i], &src_device)); + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensors[i], &dst_device)); + + OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); + + // OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + // OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + // bool copy_involves_pinned_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || + // dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; + + const void* src_data = nullptr; + void* dst_data = nullptr; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensors[i], &src_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensors[i], &dst_data)); + + if (dst_device_type == OrtMemoryInfoDeviceType_GPU) { + if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> GPU + } else { + // CPU -> GPU + } + } else if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> CPU + } else { + // CPU -> CPU involves copy to/from pinned memory and a synchronize may be required first + } + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleDataTransfer::ReleaseImpl(void* /*this_ptr*/) noexcept { + // In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore + // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) + // + // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here + // delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.h b/onnxruntime/test/autoep/library/ep_data_transfer.h new file mode 100644 index 0000000000000..d73b9e457b844 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_data_transfer.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "example_plugin_ep_utils.h" + +struct ExampleDataTransfer : OrtDataTransferImpl, ApiPtrs { + ExampleDataTransfer(ApiPtrs api_ptrs, + const OrtMemoryDevice* device_mem_info_, + const OrtMemoryDevice* shared_mem_info_ = nullptr) + : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_}, shared_mem_info{shared_mem_info_} { + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool ORT_API_CALL CanCopyImpl(void* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept; + + // function to copy one or more tensors. + // implementation can optionally use async copy if a stream is available for the input. 
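+  // Rough call shape (hypothetical values; the src/dst OrtValues must already live on devices this EP can copy between):
+  //   const OrtValue* srcs[] = {src};
+  //   OrtValue* dsts[] = {dst};
+  //   OrtSyncStream* streams[] = {nullptr};  // stream support is not wired up yet
+  //   OrtStatus* status = CopyTensorsImpl(this, srcs, dsts, streams, /*num_tensors*/ 1);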
+ static OrtStatus* ORT_API_CALL CopyTensorsImpl(void* this_ptr, + const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, + OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept; + static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept; + + private: + const OrtMemoryDevice* device_mem_info; + const OrtMemoryDevice* shared_mem_info; +}; diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc new file mode 100644 index 0000000000000..c2fa5ec88a0d8 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_factory.h" + +#include + +#include "ep.h" +#include "ep_allocator.h" +#include "ep_data_transfer.h" + +ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) + : ApiPtrs(apis), ep_name_{ep_name} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + + GetSupportedDevices = GetSupportedDevicesImpl; + + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + // for the sake of this example we specify a CPU allocator with no arena and 1K alignment (arbitrary) + // as well as GPU and GPU shared memory. the actual EP implementation would typically define two at most for a + // device (one for device memory and one for shared memory for data transfer between device and CPU) + + // setup the OrtMemoryInfo instances required by the EP. + OrtMemoryInfo* mem_info = nullptr; + auto* status = ort_api.CreateMemoryInfo_V2("ExampleEP CPU", OrtMemoryInfoDeviceType_CPU, + /*vendor*/ 0xBE57, /* device_id */ 0, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 1024, + OrtAllocatorType::OrtDeviceAllocator, // no arena + &mem_info); + assert(status == nullptr); // should never fail. + + cpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + + // + // GPU allocator OrtMemoryInfo for example purposes + mem_info = nullptr; + status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ 0xBE57, /* device_id */ 0, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info); + assert(status == nullptr); // should never fail. + default_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + + // HOST_ACCESSIBLE memory should use the non-CPU device type + mem_info = nullptr; + status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU pinned", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ 0xBE57, /* device_id */ 0, + OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info); + assert(status == nullptr); // should never fail. + host_accessible_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + + // if we were to use GPU we'd create it like this + data_transfer_impl_ = std::make_unique( + apis, + ep_api.MemoryInfo_GetMemoryDevice(default_gpu_memory_info_.get()), // device memory + ep_api.MemoryInfo_GetMemoryDevice(host_accessible_gpu_memory_info_.get()) // shared memory + ); + + data_transfer_impl_.reset(); // but we're CPU only so we return nullptr for the IDataTransfer. 
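+  // A GPU-backed factory would keep data_transfer_impl_ alive here so that CreateDataTransferImpl can return it.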
+} + +/*static*/ +const char* ORT_API_CALL ExampleEpFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_name_.c_str(); +} + +/*static*/ +const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + // C API + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // these can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.CreateKeyValuePairs(&ep_options); + + // random example using made up values + factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); + factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); + + // OrtEpDevice copies ep_metadata and ep_options. + OrtEpDevice* ep_device = nullptr; + auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_device); + + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + + // register the allocator info required by the EP. + // in this example we register CPU info which is unnecessary unless you need to override the default ORT allocator + // for a non-CPU EP this would be device info (GPU/NPU) and possible host accessible info. + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->cpu_memory_info_.get())); + + ep_devices[num_ep_devices++] = ep_device; + } + + // C++ API equivalent. Throws on error. + //{ + // Ort::ConstHardwareDevice device(devices[i]); + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // Ort::KeyValuePairs ep_metadata; + // Ort::KeyValuePairs ep_options; + // ep_metadata.Add("version", "0.1"); + // ep_options.Add("run_really_fast", "true"); + // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; + // ep_devices[num_ep_devices++] = ep_device.release(); + // } + //} + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept { + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + // we only registered for CPU and only expected to be selected for one CPU + // if you register for multiple devices (e.g. CPU, GPU and maybe NPU) you will get an entry for each device + // the EP has been selected for. 
+ return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "Example EP only supports selection for one device."); + } + + // Create the execution provider + RETURN_IF_ERROR(factory->ort_api.Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "Creating Example EP", ORT_FILE, __LINE__, __FUNCTION__)); + + // use properties from the device and ep_metadata if needed + // const OrtHardwareDevice* device = devices[0]; + // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; + + // Create EP configuration from session options, if needed. + // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. + std::string ep_context_enable; + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(factory->ort_api, *session_options, + "ep.context_enable", "0", ep_context_enable)); + + ExampleEp::Config config = {}; + config.enable_ep_context = ep_context_enable == "1"; + + auto dummy_ep = std::make_unique(*factory, factory->ep_name_, config, *logger); + + *ep = dummy_ep.release(); + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleEpFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + ExampleEp* dummy_ep = static_cast(ep); + delete dummy_ep; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto& factory = *static_cast(this_ptr); + *allocator = nullptr; + + // NOTE: The factory implementation can return a shared OrtAllocator* instead of creating a new instance on each call. + // To do this just make ReleaseAllocatorImpl a no-op. + + // NOTE: If OrtMemoryInfo has allocator type (call MemoryInfoGetType) of OrtArenaAllocator, an ORT BFCArena + // will be added to wrap the returned OrtAllocator. The EP is free to implement its own arena, and if it + // wants to do this the OrtMemoryInfo MUST be created with an allocator type of OrtDeviceAllocator. + + // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based + // matching should work. + if (memory_info == factory.cpu_memory_info_.get()) { + // create a CPU allocator. use the basic OrtAllocator for this example. + auto cpu_allocator = std::make_unique(memory_info); + *allocator = cpu_allocator.release(); + } else if (memory_info == factory.default_gpu_memory_info_.get()) { + // create a GPU allocator + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Example is not implemented."); + } else if (memory_info == factory.host_accessible_gpu_memory_info_.get()) { + // create a pinned/shared memory allocator. Use the real device type (i.e. GPU/NPU) and id and a memory type of + // OrtDeviceMemoryType_HOST_ACCESSIBLE. + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Example is not implemented."); + } else { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. 
" + "Value did not come directly from an OrtEpDevice returned by this factory."); + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleEpFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { + delete static_cast(allocator); +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + auto& factory = *static_cast(this_ptr); + *data_transfer = factory.data_transfer_impl_.get(); + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h new file mode 100644 index 0000000000000..8ab67fc9d8ce6 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "ep_data_transfer.h" +#include "example_plugin_ep_utils.h" + +/// +/// Example EP factory that can create an OrtEp and return information about the supported hardware devices. +/// +class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { + public: + ExampleEpFactory(const char* ep_name, ApiPtrs apis); + + OrtDataTransferImpl* GetDataTransfer() const { + return data_transfer_impl_.get(); + } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + const std::string ep_name_; // EP name + const std::string vendor_{"Contoso"}; // EP vendor name + + // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. + using MemoryInfoUniquePtr = std::unique_ptr>; + MemoryInfoUniquePtr cpu_memory_info_; + + // for example purposes. if the EP used GPU, and pinned/shared memory was required for data transfer, these are the + // OrtMemoryInfo instance required for that. 
+ MemoryInfoUniquePtr default_gpu_memory_info_; + MemoryInfoUniquePtr host_accessible_gpu_memory_info_; + + std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index 9978189267a40..23a61fe9a45cd 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -1,521 +1,7 @@ -#include -#include -#include -#include -#include -#include -#include +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. -#include "example_plugin_ep_utils.h" - -#define ORT_API_MANUAL_INIT -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - -struct ExampleEp; - -/// -/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. -/// -struct MulKernel { - MulKernel(const OrtApi& ort_api, const OrtLogger& logger) : ort_api(ort_api), logger(logger) {} - - OrtStatus* Compute(OrtKernelContext* kernel_context) { - RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); - size_t num_inputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(kernel_context, &num_inputs)); - RETURN_IF(num_inputs != 2, ort_api, "Expected 2 inputs for MulKernel"); - - size_t num_outputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutputCount(kernel_context, &num_outputs)); - RETURN_IF(num_outputs != 1, ort_api, "Expected 1 output for MulKernel"); - - const OrtValue* input0 = nullptr; - const OrtValue* input1 = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, 0, &input0)); - RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, 1, &input1)); - - OrtTensorTypeAndShapeInfo* type_shape0 = nullptr; - OrtTensorTypeAndShapeInfo* type_shape1 = nullptr; - DeferOrtRelease release_type0(&type_shape0, ort_api.ReleaseTensorTypeAndShapeInfo); - DeferOrtRelease release_type1(&type_shape1, ort_api.ReleaseTensorTypeAndShapeInfo); - - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input0, &type_shape0)); - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input1, &type_shape1)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape0, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); - - size_t num_dims0 = 0; - size_t num_dims1 = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape0, &num_dims0)); - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape1, &num_dims1)); - RETURN_IF((num_dims0 == 0) || (num_dims1 == 0), ort_api, "Input has 0 dimensions"); - RETURN_IF(num_dims0 != num_dims1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting - - std::vector dims0(num_dims0, 0); - std::vector dims1(num_dims1, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape0, dims0.data(), dims0.size())); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape1, dims1.data(), dims1.size())); - RETURN_IF(dims0 != dims1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. - - const float* input_data0 = nullptr; - const float* input_data1 = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(input0), (void**)&input_data0)); // No const-correct API? 
- RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(input1), (void**)&input_data1)); - - OrtValue* output = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, dims0.data(), dims0.size(), &output)); - - float* output_data = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorMutableData(output, reinterpret_cast(&output_data))); - - int64_t num_elems = 1; - for (int64_t dim : dims0) { - RETURN_IF(dim < 0, ort_api, "Invalid dimension: negative value detected"); - num_elems *= dim; - } - - for (size_t i = 0; i < static_cast(num_elems); ++i) { - output_data[i] = input_data0[i] * input_data1[i]; - } - - return nullptr; - } - - const OrtApi& ort_api; - const OrtLogger& logger; -}; - -/// -/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. -/// -struct ExampleNodeComputeInfo : OrtNodeComputeInfo { - explicit ExampleNodeComputeInfo(ExampleEp& ep); - - static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, - OrtNodeComputeContext* compute_context, - void** compute_state); - static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, - OrtKernelContext* kernel_context); - static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); - - ExampleEp& ep; -}; - -struct ApiPtrs { - const OrtApi& ort_api; - const OrtEpApi& ep_api; - const OrtModelEditorApi& model_editor_api; -}; - -/// -/// Example EP that can compile a single Mul operator. -/// -struct ExampleEp : OrtEp, ApiPtrs { - struct Config { - bool enable_ep_context = false; - // Other EP configs (typically extracted from OrtSessionOptions or OrtHardwareDevice(s)) - }; - - ExampleEp(ApiPtrs apis, const std::string& name, const Config& config, const OrtLogger& logger) - : ApiPtrs(apis), name_{name}, config_{config}, logger_{logger} { - // Initialize the execution provider. - auto status = ort_api.Logger_LogMessage(&logger_, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - ("ExampleEp has been created with name " + name_).c_str(), - ORT_FILE, __LINE__, __FUNCTION__); - // ignore status for now - (void)status; - - ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
- GetName = GetNameImpl; - GetCapability = GetCapabilityImpl; - Compile = CompileImpl; - ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; - } - - ~ExampleEp() { - // Clean up the execution provider - } - - static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) { - const auto* ep = static_cast(this_ptr); - return ep->name_.c_str(); - } - - static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) { - ExampleEp* ep = static_cast(this_ptr); - - OrtArrayOfConstObjects* nodes_array = nullptr; - DeferOrtRelease release_nodes_array(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); - - size_t num_nodes = 0; - - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, &nodes_array)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); - - if (num_nodes == 0) { - return nullptr; // No nodes to process - } - - const void* const* nodes_data = nullptr; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(nodes_array, &nodes_data)); - auto nodes_span = gsl::span(reinterpret_cast(nodes_data), num_nodes); - - std::vector supported_nodes; - - for (const OrtNode* node : nodes_span) { - const char* op_type = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); - - if (std::strncmp(op_type, "Mul", 4) == 0) { - // Check that Mul has inputs/output of type float - OrtArrayOfConstObjects* inputs_array = nullptr; - OrtArrayOfConstObjects* outputs_array = nullptr; - DeferOrtRelease release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects); - DeferOrtRelease release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects); - - RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, &inputs_array)); - RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, &outputs_array)); - - size_t num_inputs = 0; - size_t num_outputs = 0; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_inputs)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_outputs)); - RETURN_IF(num_inputs != 2 || num_outputs != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - - const void* const* inputs_data = nullptr; - const void* const* outputs_data = nullptr; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(inputs_array, &inputs_data)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(outputs_array, &outputs_data)); - - std::array is_float = {false, false, false}; - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[0]), is_float[0])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[1]), is_float[1])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(outputs_data[0]), is_float[2])); - if (!is_float[0] || !is_float[1] || !is_float[2]) { - continue; // Input or output is not of type float - } - - supported_nodes.push_back(node); // Only support a single Mul for now. 
- break; - } - } - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), - supported_nodes.size())); - return nullptr; - } - - static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, - _In_ const OrtNode** fused_nodes, _In_ size_t count, - _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) { - ExampleEp* ep = static_cast(this_ptr); - - if (count != 1) { - return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single graph"); - } - - OrtArrayOfConstObjects* nodes_array = nullptr; - DeferOrtRelease release_nodes(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); - size_t num_nodes = 0; - - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[0], &nodes_array)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); - - if (num_nodes != 1) { - return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } - - const OrtNode* node_to_compile = nullptr; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(nodes_array, 0, - reinterpret_cast(&node_to_compile))); - - const char* node_op_type = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node_to_compile, &node_op_type)); - - if (std::strncmp(node_op_type, "Mul", 4) != 0) { - return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } - - // Now we know we're compiling a single Mul node. - // Associate the name of the fused node with our MulKernel. - const char* fused_node_name = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); - - ep->kernels.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_)); - - // Update the OrtNodeComputeInfo associated with the graph. - auto node_compute_info = std::make_unique(*ep); - node_compute_infos[0] = node_compute_info.release(); - - // Create EpContext nodes for the fused nodes we compiled. - if (ep->config_.enable_ep_context) { - assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), - gsl::span(ep_context_nodes, count))); - } - - return nullptr; - } - - static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, - OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) { - (void)this_ptr; - for (size_t i = 0; i < num_node_compute_infos; i++) { - delete node_compute_infos[i]; - } - } - - // Creates EPContext nodes from the given fused nodes. - // This is an example implementation that can be used to generate an EPContext model. However, this example EP - // cannot currently run the EPContext model. - OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, - /*out*/ gsl::span ep_context_nodes) { - assert(fused_nodes.size() == ep_context_nodes.size()); - - // Helper to collect input or output names from an array of OrtValueInfo instances. 
- auto collect_input_output_names = [&](const OrtArrayOfConstObjects& value_infos, - std::vector& result) -> OrtStatus* { - size_t num_values = 0; - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(&value_infos, &num_values)); - - std::vector value_names(num_values, nullptr); - - for (size_t i = 0; i < num_values; i++) { - const void* value_info = nullptr; // Is a const OrtValueInfo* - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(&value_infos, i, &value_info)); - RETURN_IF_ERROR(ort_api.GetValueInfoName(static_cast(value_info), &value_names[i])); - } - - result = std::move(value_names); - return nullptr; - }; - - // Create an "EPContext" node for every fused node. - for (size_t i = 0; i < fused_nodes.size(); ++i) { - const OrtNode* fused_node = fused_nodes[i]; - const char* fused_node_name = nullptr; - - RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &fused_node_name)); - - OrtArrayOfConstObjects* fused_node_inputs = nullptr; - OrtArrayOfConstObjects* fused_node_outputs = nullptr; - DeferOrtRelease defer_release0(&fused_node_inputs, ort_api.ReleaseArrayOfConstObjects); - DeferOrtRelease defer_release1(&fused_node_outputs, ort_api.ReleaseArrayOfConstObjects); - - RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, &fused_node_inputs)); - RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, &fused_node_outputs)); - - std::vector input_names; - std::vector output_names; - - RETURN_IF_ERROR(collect_input_output_names(*fused_node_inputs, /*out*/ input_names)); - RETURN_IF_ERROR(collect_input_output_names(*fused_node_outputs, /*out*/ output_names)); - - int64_t is_main_context = (i == 0); - int64_t embed_mode = 1; - - // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. - std::array attributes = {}; - DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); - - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", "binary_data", 1, ORT_OP_ATTR_STRING, &attributes[0])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT, &attributes[1])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[2])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING, &attributes[3])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, 1, ORT_OP_ATTR_STRING, &attributes[4])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[5])); - - RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, - input_names.data(), input_names.size(), - output_names.data(), output_names.size(), - attributes.data(), attributes.size(), - &ep_context_nodes[i])); - } - - return nullptr; - } - - std::string name_; - Config config_{}; - const OrtLogger& logger_; - std::unordered_map> kernels; -}; - -// -// Implementation of ExampleNodeComuteInfo -// -ExampleNodeComputeInfo::ExampleNodeComputeInfo(ExampleEp& ep) : ep(ep) { - ort_version_supported = ORT_API_VERSION; - CreateState = CreateStateImpl; - Compute = ComputeImpl; - ReleaseState = ReleaseStateImpl; -} - -OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, - OrtNodeComputeContext* compute_context, - void** compute_state) { - auto* node_compute_info = static_cast(this_ptr); - ExampleEp& ep = node_compute_info->ep; - - std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); - auto kernel_it = 
ep.kernels.find(fused_node_name); - if (kernel_it == ep.kernels.end()) { - std::string message = "Unable to get kernel for fused node with name " + fused_node_name; - return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); - } - - MulKernel& kernel = *kernel_it->second; - *compute_state = &kernel; - return nullptr; -} - -OrtStatus* ExampleNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, - OrtKernelContext* kernel_context) { - (void)this_ptr; - MulKernel& kernel = *reinterpret_cast(compute_state); - return kernel.Compute(kernel_context); -} - -void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { - (void)this_ptr; - MulKernel& kernel = *reinterpret_cast(compute_state); - (void)kernel; - // Do nothing for this example. -} - -/// -/// Example EP factory that can create an OrtEp and return information about the supported hardware devices. -/// -struct ExampleEpFactory : OrtEpFactory, ApiPtrs { - ExampleEpFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { - ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. - GetName = GetNameImpl; - GetVendor = GetVendorImpl; - GetSupportedDevices = GetSupportedDevicesImpl; - CreateEp = CreateEpImpl; - ReleaseEp = ReleaseEpImpl; - } - - static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) { - const auto* factory = static_cast(this_ptr); - return factory->ep_name_.c_str(); - } - - static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) { - const auto* factory = static_cast(this_ptr); - return factory->vendor_.c_str(); - } - - static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) { - size_t& num_ep_devices = *p_num_ep_devices; - auto* factory = static_cast(this_ptr); - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - // C API - const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - // these can be returned as nullptr if you have nothing to add. - OrtKeyValuePairs* ep_metadata = nullptr; - OrtKeyValuePairs* ep_options = nullptr; - factory->ort_api.CreateKeyValuePairs(&ep_metadata); - factory->ort_api.CreateKeyValuePairs(&ep_options); - - // random example using made up values - factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); - factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); - - // OrtEpDevice copies ep_metadata and ep_options. - auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, - &ep_devices[num_ep_devices++]); - - factory->ort_api.ReleaseKeyValuePairs(ep_metadata); - factory->ort_api.ReleaseKeyValuePairs(ep_options); - - if (status != nullptr) { - return status; - } - } - - // C++ API equivalent. Throws on error. 
- //{ - // Ort::ConstHardwareDevice device(devices[i]); - // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - // Ort::KeyValuePairs ep_metadata; - // Ort::KeyValuePairs ep_options; - // ep_metadata.Add("version", "0.1"); - // ep_options.Add("run_really_fast", "true"); - // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; - // ep_devices[num_ep_devices++] = ep_device.release(); - // } - //} - } - - return nullptr; - } - - static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, - _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ OrtEp** ep) { - auto* factory = static_cast(this_ptr); - *ep = nullptr; - - if (num_devices != 1) { - // we only registered for CPU and only expected to be selected for one CPU - // if you register for multiple devices (e.g. CPU, GPU and maybe NPU) you will get an entry for each device - // the EP has been selected for. - return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, - "Example EP only supports selection for one device."); - } - - // Create the execution provider - RETURN_IF_ERROR(factory->ort_api.Logger_LogMessage(logger, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "Creating Example EP", ORT_FILE, __LINE__, __FUNCTION__)); - - // use properties from the device and ep_metadata if needed - // const OrtHardwareDevice* device = devices[0]; - // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; - - // Create EP configuration from session options, if needed. - // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. - std::string ep_context_enable; - RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(factory->ort_api, *session_options, - "ep.context_enable", "0", ep_context_enable)); - - ExampleEp::Config config = {}; - config.enable_ep_context = ep_context_enable == "1"; - - auto dummy_ep = std::make_unique(*factory, factory->ep_name_, config, *logger); - - *ep = dummy_ep.release(); - return nullptr; - } - - static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) { - ExampleEp* dummy_ep = static_cast(ep); - delete dummy_ep; - } - - const std::string ep_name_; // EP name - const std::string vendor_{"Contoso"}; // EP vendor name -}; +#include "ep_factory.h" // To make symbols visible on macOS/iOS #ifdef __APPLE__ @@ -531,13 +17,13 @@ extern "C" { EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base, OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); - const OrtEpApi* ort_ep_api = ort_api->GetEpApi(); - const OrtModelEditorApi* ort_model_editor_api = ort_api->GetModelEditorApi(); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); // Factory could use registration_name or define its own EP name. 
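 // A minimal sketch of the contract this entry point follows (not a quote of the elided code
 // below): write at most `max_factories` entries into `factories` and report how many were
 // written, e.g. for a single-factory library:
 //   factories[0] = factory.release();
 //   *num_factories = 1;
 //   return nullptr;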
std::unique_ptr factory = std::make_unique(registration_name, - ApiPtrs{*ort_api, *ort_ep_api, - *ort_model_editor_api}); + ApiPtrs{*ort_api, *ep_api, + *model_editor_api}); if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, @@ -551,7 +37,7 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const } EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { - delete factory; + delete static_cast(factory); return nullptr; } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index ae0a86bbb7222..e107a94410dba 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -10,12 +10,12 @@ #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT -#define RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* status = (fn); \ - if (status != nullptr) { \ - return status; \ - } \ +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ } while (0) #define RETURN_IF(cond, ort_api, msg) \ @@ -25,6 +25,12 @@ } \ } while (0) +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; + const OrtModelEditorApi& model_editor_api; +}; + // Helper to release Ort one or more objects obtained from the public C API at the end of their scope. template struct DeferOrtRelease { @@ -49,6 +55,11 @@ struct DeferOrtRelease { std::function release_func_ = nullptr; }; +struct FloatInitializer { + std::vector shape; + std::vector data; +}; + // Returns an entry in the session option configurations, or a default value if not present. OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, const char* config_key, const std::string& default_val, diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index 79070f0788f2a..be72fbd460c9b 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -75,9 +75,6 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) { session_options.AppendExecutionProvider_CUDA_V2(cuda_options); } #endif - if (is_rocm) { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); - } // The following model was obtained by padding the vocabulary size in testdata/transformers/tiny_gpt2_beamsearch_fp16.onnx // (by making beam_size == 1) from 1000 to 1600 (just for illustrative and testing purposes) to see if the greedy search @@ -160,9 +157,6 @@ TEST(GreedySearchTest, GptGreedySearchFp32) { session_options.AppendExecutionProvider_CUDA_V2(cuda_options); } #endif - if (is_rocm) { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); - } Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_greedysearch_with_init_decoder.onnx"), session_options); diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 97e1de4b6ad16..7b77ca8c69225 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -15,6 +15,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/mlas/inc/mlas.h" #include "core/session/inference_session.h" +#include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include 
"test/framework/test_utils.h" #include "test/optimizer/graph_transform_test_builder.h" @@ -103,10 +104,14 @@ void RunTest(const TestOptions& opts, std::vector>&& explicit_eps = {}) { SCOPED_TRACE(opts); - static_assert(std::is_same_v || std::is_same_v, + static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "unexpected type for T1"); - constexpr bool use_float16 = std::is_same_v; +#ifdef USE_CUDA + if (opts.accuracy_level != 0 && !opts.legacy_shape) { + return; // CUDA EP does not handle accuracy level, so only test one level to avoid unnecessary tests. + } +#endif const bool zp_is_4bit = opts.zp_is_4bit || opts.has_g_idx; @@ -163,10 +168,12 @@ void RunTest(const TestOptions& opts, test.AddAttribute("bits", QBits); test.AddAttribute("accuracy_level", opts.accuracy_level); - if constexpr (use_float16) { - test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); - } else { + if constexpr (std::is_same_v) { test.AddInput("A", {M, K}, input0_vals, false); + } else if constexpr (std::is_same::value) { + test.AddInput("A", {M, K}, FloatsToMLFloat16s(input0_vals), false); + } else if constexpr (std::is_same::value) { + test.AddInput("A", {M, K}, FloatsToBFloat16s(input0_vals), false); } test.AddInput("B", {N, k_blocks, blob_size}, input1_vals, true); @@ -174,10 +181,12 @@ void RunTest(const TestOptions& opts, auto scales_shape = opts.legacy_shape ? std::vector{N * k_blocks} : std::vector{N, k_blocks}; - if constexpr (use_float16) { - test.AddInput("scales", scales_shape, ToFloat16(scales), true); - } else { + if constexpr (std::is_same::value) { test.AddInput("scales", scales_shape, scales, true); + } else if constexpr (std::is_same::value) { + test.AddInput("scales", scales_shape, FloatsToMLFloat16s(scales), true); + } else if constexpr (std::is_same::value) { + test.AddInput("scales", scales_shape, FloatsToBFloat16s(scales), true); } if (opts.has_zero_point) { @@ -198,10 +207,12 @@ void RunTest(const TestOptions& opts, ind -= q_scale_size / N + 1; } - if constexpr (use_float16) { - test.AddInput("zero_points", scales_shape, ToFloat16(zp_f), true); - } else { + if constexpr (std::is_same_v) { test.AddInput("zero_points", scales_shape, zp_f, true); + } else if constexpr (std::is_same_v) { + test.AddInput("zero_points", scales_shape, FloatsToMLFloat16s(zp_f), true); + } else if constexpr (std::is_same_v) { + test.AddInput("zero_points", scales_shape, FloatsToBFloat16s(zp_f), true); } } } else { @@ -225,19 +236,23 @@ void RunTest(const TestOptions& opts, } if (bias.has_value()) { - if constexpr (use_float16) { - test.AddInput("bias", bias_shape, ToFloat16(*bias), true); - } else { + if constexpr (std::is_same::value) { test.AddInput("bias", bias_shape, *bias, true); + } else if constexpr (std::is_same::value) { + test.AddInput("bias", bias_shape, FloatsToMLFloat16s(*bias), true); + } else if constexpr (std::is_same::value) { + test.AddInput("bias", bias_shape, FloatsToBFloat16s(*bias), true); } } else { test.AddOptionalInputEdge(); } - if constexpr (use_float16) { - test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); - } else { + if constexpr (std::is_same::value) { test.AddOutput("Y", {M, N}, expected_vals); + } else if constexpr (std::is_same::value) { + test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(expected_vals)); + } else if constexpr (std::is_same::value) { + test.AddOutput("Y", {M, N}, FloatsToBFloat16s(expected_vals)); } if (opts.output_abs_error.has_value()) { @@ -258,17 +273,27 @@ void RunTest(const TestOptions& opts, } // namespace template -void TestMatMulNBitsTyped() 
{ +void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, + std::optional rel_error = std::nullopt) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; base_opts.accuracy_level = accuracy_level; + base_opts.legacy_shape = legacy_shape; - if (base_opts.accuracy_level == 4) { + if (abs_error.has_value()) { + base_opts.output_abs_error = *abs_error; + } else if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; - base_opts.output_rel_error = 0.02f; } else if constexpr (std::is_same::value) { base_opts.output_abs_error = 0.055f; + } + + if (rel_error.has_value()) { + base_opts.output_rel_error = *rel_error; + } else if (base_opts.accuracy_level == 4) { + base_opts.output_rel_error = 0.02f; + } else if constexpr (std::is_same::value) { base_opts.output_rel_error = 0.02f; } @@ -311,7 +336,8 @@ void TestMatMulNBitsTyped() { { TestOptions opts = base_opts; - opts.has_zero_point = true, opts.zp_is_4bit = false; + opts.has_zero_point = true; + opts.zp_is_4bit = false; RunTest(opts); } #endif // !defined(USE_DML) && !defined(USE_WEBGPU) @@ -473,27 +499,32 @@ namespace { // Legacy test function. // This has too many parameters of the same type that must be specified in the correct order. // Consider using the overload with a TestOptions parameter. -void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, - bool has_zeropoint, bool use_float16, bool has_g_idx = false, - bool zp_is_4bit = true, float fp16_abs_error = 0.02f, bool has_bias = false) { +template +void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zeropoint, bool zp_is_4bit = true, + float abs_error = 0.f, bool has_g_idx = false, bool has_bias = false) { TestOptions opts{}; opts.M = M; opts.N = N; opts.K = K; opts.block_size = block_size; - opts.accuracy_level = accuracy_level; + opts.accuracy_level = 0; opts.has_zero_point = has_zeropoint; opts.zp_is_4bit = zp_is_4bit; opts.has_g_idx = has_g_idx; opts.has_bias = has_bias; - if (use_float16) { - opts.output_abs_error = fp16_abs_error; - opts.output_rel_error = use_float16 ? 0.001f : 0.0005f; + if (abs_error > 0.f) { + opts.output_abs_error = abs_error; + } + + if (std::is_same_v) { + opts.output_rel_error = 0.001f; + } else if (std::is_same_v) { + opts.output_rel_error = 0.0005f; } std::vector> execution_providers; - if (use_float16) { + if (std::is_same_v) { #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); #endif @@ -516,28 +547,27 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura RunTest(opts, std::move(execution_providers)); } } -} // namespace -TEST(MatMulNBits, Float16Cuda) { -#if defined(USE_CUDA) || defined(USE_ROCM) - auto has_gidx_options = {true, false}; -#else - auto has_gidx_options = {false}; -#endif +constexpr bool kPipelineMode = true; // CI pipeline? 
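// Flip kPipelineMode to false to run the exhaustive Float16 sweep below locally; it is skipped
// in CI because of the number of parameter combinations. For example (assuming the standard
// unit test driver binary):
//   ./onnxruntime_test_all --gtest_filter=MatMulNBits.Float16_Comprehensive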
+} // namespace - for (auto M : {1, 2, 100}) { - for (auto N : {1, 2, 32, 288}) { - for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { - for (auto block_size : {16, 32, 64, 128}) { - for (auto has_gidx : has_gidx_options) { -#if defined(USE_DML) - RunTest(M, N, K, block_size, 0, false, true, has_gidx, true, 0.04f); -#elif defined(USE_WEBGPU) - RunTest(M, N, K, block_size, 0, false, true, has_gidx, true, 0.03f); -#else - RunTest(M, N, K, block_size, 0, false, true, has_gidx); - RunTest(M, N, K, block_size, 0, true, true, has_gidx, false); -#endif +TEST(MatMulNBits, Float16_Comprehensive) { + if constexpr (kPipelineMode) { + GTEST_SKIP() << "Skipping in pipeline mode"; // This test has too many combinations. Skip it in CI pipeline. + } else { + constexpr float abs_error = 0.02f; + + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + for (auto has_g_idx : {false, true}) { + for (auto has_zero_point : {false, true}) { + for (auto is_zero_point_4bit : {false, true}) { + RunTest(M, N, K, block_size, has_zero_point, is_zero_point_4bit, abs_error, has_g_idx); + } + } + } } } } @@ -545,69 +575,132 @@ TEST(MatMulNBits, Float16Cuda) { } } -TEST(MatMulNBits, Float16Large) { +TEST(MatMulNBits, Float16_Large) { #ifdef USE_DML // For some reason, the A10 machine that runs these tests during CI has a much bigger error than all retail // machines we tested on. All consumer-grade machines from Nvidia/AMD/Intel seem to pass these tests with an // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. 
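  // A sketch of what a ULP-based tolerance could look like (illustrative only; not wired into
  // these tests, and NaN handling is omitted): reinterpret the bits as signed integers, map
  // negatives onto a monotonic integer line, and compare the distance.
  //   int64_t UlpDistance(float a, float b) {
  //     int32_t ia, ib;
  //     std::memcpy(&ia, &a, sizeof(ia));  // requires <cstring>
  //     std::memcpy(&ib, &b, sizeof(ib));
  //     if (ia < 0) ia = std::numeric_limits<int32_t>::min() - ia;
  //     if (ib < 0) ib = std::numeric_limits<int32_t>::min() - ib;
  //     return std::abs(static_cast<int64_t>(ia) - static_cast<int64_t>(ib));
  //   }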
- float abs_error = 0.3f; + constexpr float abs_error = 0.3f; #else - float abs_error = 0.1f; + constexpr float abs_error = 0.1f; #endif + constexpr bool zp_is_4bit = true; + for (auto block_size : {16, 32, 64, 128}) { - for (auto symmetric : {false, true}) { - RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, abs_error); - RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, abs_error); - RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, abs_error); + for (auto has_zeropoint : {false, true}) { + RunTest(1, 4096, 4096, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(1, 4096, 11008, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(1, 11008, 4096, block_size, has_zeropoint, zp_is_4bit, abs_error); } } } #ifdef USE_CUDA TEST(MatMulNBits, Fp16_Int4_Int4ZeroPoint) { - float abs_error = 0.1f; - constexpr bool use_float16 = true; - constexpr bool has_g_idx = false; + constexpr float abs_error = 0.1f; constexpr bool zp_is_4bit = true; constexpr bool has_zeropoint = true; + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); + } + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; for (auto block_size : {64, 128}) { - RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); - RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); } } TEST(MatMulNBits, Fp16_Int4_Fp16ZeroPoint) { - float abs_error = 0.1f; - constexpr bool use_float16 = true; - constexpr bool has_g_idx = false; + constexpr float abs_error = 0.1f; constexpr bool zp_is_4bit = false; constexpr bool has_zeropoint = true; ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; for (auto block_size : {64, 128}) { - RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); - RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, BFloat16_Int4_Int4ZeroPoint) { + constexpr float abs_error = 0.1f; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = true; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); + } + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, BFloat16_Int4_BFloat16ZeroPoint) { + if (!HasCudaEnvironment(800)) { + GTEST_SKIP() << "Skipping BFloat16 8-bit MatMul tests on CUDA < 8.0"; + } + + constexpr float abs_error = 0.1f; + constexpr bool zp_is_4bit = false; + constexpr bool has_zeropoint = true; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 
has_zeropoint, zp_is_4bit, abs_error); + } + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); } } TEST(MatMulNBits, Fp16_Int4_NoZeroPoint) { - float abs_error = 0.1f; - constexpr bool use_float16 = true; - constexpr bool has_g_idx = false; + constexpr float abs_error = 0.1f; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = false; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, BFloat16_Int4_NoZeroPoint) { + if (!HasCudaEnvironment(800)) { + GTEST_SKIP() << "Skipping BFloat16 8-bit MatMul tests on CUDA < 8.0"; + } + + constexpr float abs_error = 0.5f; constexpr bool zp_is_4bit = true; constexpr bool has_zeropoint = false; + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); + } + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; for (auto block_size : {64, 128}) { - RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); - RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(1, 256, 1024, block_size, has_zeropoint, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, has_zeropoint, zp_is_4bit, abs_error); } } #endif diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index fc802054036d8..bd7ee13aeae31 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -14,6 +14,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/mlas/inc/mlas.h" #include "core/session/inference_session.h" +#include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" #include "test/optimizer/graph_transform_test_builder.h" @@ -64,6 +65,12 @@ template void RunTest8Bits(const TestOptions8Bits& opts) { SCOPED_TRACE(opts); +#ifdef USE_CUDA + if (opts.accuracy_level != 0) { + return; // CUDA EP does not handle accuracy level, so only test one level to avoid unnecessary tests. 
+ } +#endif + const int64_t M = opts.M, K = opts.K, N = opts.N; @@ -140,8 +147,10 @@ void RunTest8Bits(const TestOptions8Bits& opts) { test.AddAttribute("accuracy_level", opts.accuracy_level); if constexpr (std::is_same::value) { test.AddInput("A", {M, K}, input0_fp32_vals, false); - } else { + } else if constexpr (std::is_same::value) { test.AddInput("A", {M, K}, FloatsToMLFloat16s(input0_fp32_vals), false); + } else if constexpr (std::is_same::value) { + test.AddInput("A", {M, K}, FloatsToBFloat16s(input0_fp32_vals), false); } int64_t k_blocks = (K + opts.block_size - 1) / opts.block_size; @@ -149,8 +158,10 @@ void RunTest8Bits(const TestOptions8Bits& opts) { if constexpr (std::is_same::value) { test.AddInput("scales", {N, static_cast(q_scale_size) / N}, scales, true); - } else { + } else if constexpr (std::is_same::value) { test.AddInput("scales", {N, static_cast(q_scale_size) / N}, FloatsToMLFloat16s(scales), true); + } else if constexpr (std::is_same::value) { + test.AddInput("scales", {N, static_cast(q_scale_size) / N}, FloatsToBFloat16s(scales), true); } if (opts.has_zero_point) { @@ -165,8 +176,10 @@ void RunTest8Bits(const TestOptions8Bits& opts) { if (bias.has_value()) { if constexpr (std::is_same::value) { test.AddInput("bias", bias_shape, *bias, true); - } else { + } else if constexpr (std::is_same::value) { test.AddInput("bias", bias_shape, FloatsToMLFloat16s(*bias), true); + } else if constexpr (std::is_same::value) { + test.AddInput("bias", bias_shape, FloatsToBFloat16s(*bias), true); } } else { test.AddOptionalInputEdge(); @@ -174,8 +187,10 @@ void RunTest8Bits(const TestOptions8Bits& opts) { if constexpr (std::is_same::value) { test.AddOutput("Y", {M, N}, expected_vals); - } else { + } else if constexpr (std::is_same::value) { test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(expected_vals)); + } else if constexpr (std::is_same::value) { + test.AddOutput("Y", {M, N}, FloatsToBFloat16s(expected_vals)); } if (opts.output_abs_error.has_value()) { @@ -207,7 +222,7 @@ void RunTest8Bits(const TestOptions8Bits& opts) { #endif } -template +template void TestMatMul8BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { TestOptions8Bits base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; @@ -251,6 +266,14 @@ void TestMatMul8BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { } // namespace TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { +#if defined(__ANDROID__) && defined(MLAS_TARGET_AMD64_IX86) + // Fails on Android CI build on Linux host machine: + // [ RUN ] MatMulNBits.Float32_8b_AccuracyLevel4 + // Test was not executed. + // Trap + // TODO investigate failure + GTEST_SKIP() << "Skipping test on Android x86_64 (emulator)."; +#endif TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); @@ -286,6 +309,14 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { } TEST(MatMulNBits, Float32_8b_AccuracyLevel1) { +#if defined(__ANDROID__) && defined(MLAS_TARGET_AMD64_IX86) + // Fails on Android CI build on Linux host machine: + // [ RUN ] MatMulNBits.Float32_8b_AccuracyLevel1 + // Trap + // TODO investigate failure + GTEST_SKIP() << "Skipping test on Android x86_64 (emulator)."; +#endif + // At the time of writing these tests, Fp32 activations + 8 bit weights + Accuracy level 1 // do not have MLAS optimized kernels on any platform and hence this will use the "unpacked" // compute mode (i.e.) 
de-quantize the 8 bit weights to fp32 and invoke vanilla fp32 Gemm @@ -324,7 +355,7 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel1) { TestMatMul8BitsTyped(); } -#if defined(USE_CUDA) || defined(USE_WEBGPU) +#if defined(USE_WEBGPU) TEST(MatMulNBits, Float16_8b_AccuracyLevel4) { constexpr float abs_error = 0.055f; constexpr float rel_error = 0.02f; @@ -342,10 +373,37 @@ TEST(MatMulNBits, Fp16_Int8_Cuda) { ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; - TestMatMul8BitsTyped(abs_error, rel_error); - TestMatMul8BitsTyped(abs_error, rel_error); - TestMatMul8BitsTyped(abs_error, rel_error); - TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); +} + +TEST(MatMulNBits, BFloat16_8bits) { + if (!HasCudaEnvironment(800)) { + GTEST_SKIP() << "Skipping BFloat16 8-bit MatMul tests on CUDA < 8.0"; + } + + constexpr float abs_error = 0.055f; + constexpr float rel_error = 0.02f; + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); +} + +TEST(MatMulNBits, BFloat16_Int8_Gemm_Cuda) { + if (!HasCudaEnvironment(800)) { + GTEST_SKIP() << "Skipping BFloat16 8-bit MatMul tests on CUDA < 8.0"; + } + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + constexpr float abs_error = 0.5f; + constexpr float rel_error = 0.05f; + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); } #endif diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 64935929db070..d47ebce483be4 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -287,8 +287,11 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::spanName() == graph_output->Name(); }); bool is_const_initializer = false; - const ONNX_NAMESPACE::TensorProto* initializer = graph_viewer.GetGraph().GetInitializer(value_name, true, - is_const_initializer); + OrtValue initializer_value; + const ONNX_NAMESPACE::TensorProto* initializer = graph_viewer.GetGraph().GetInitializer(value_name, + initializer_value, + is_const_initializer, + /*check_outer_scope*/ true); bool can_override_initializer = graph_viewer.CanOverrideInitializer(); bool api_is_req_graph_input = false; @@ -449,6 +452,79 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args); + // Check node attributes + const auto& node_attrs = node->GetAttributes(); + + if (node_attrs.size() > 0) { + OrtArrayOfConstObjects* api_node_attributes = nullptr; + DeferOrtRelease release_node_attributes(&api_node_attributes, + ort_api.ReleaseArrayOfConstObjects); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, &api_node_attributes)); + CheckArrayObjectType(api_node_attributes, ORT_TYPE_TAG_OrtOpAttr); + + size_t attr_idx = 0; + for (const auto& node_attr : node_attrs) { + const OrtOpAttr* api_node_attr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_node_attributes, attr_idx, + reinterpret_cast(&api_node_attr))); + ASSERT_NE(api_node_attr, nullptr); + + api_node_attr = nullptr; + 
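        // Reset the pointer and fetch the same attribute again, by name this time, so that both
        // lookup paths (ArrayOfConstObjects_GetElementAt above, Node_GetAttributeByName below)
        // are exercised against the same OrtOpAttr.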
ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(api_node, node_attr.first.c_str(), &api_node_attr)); + ASSERT_NE(api_node_attr, nullptr); + + const char* api_node_attr_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(api_node_attr, &api_node_attr_name)); + ASSERT_STREQ(api_node_attr_name, node_attr.first.c_str()); + + OrtOpAttrType api_node_attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + + // It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping. + // In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here. + OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + continue; + } + + ONNX_NAMESPACE::AttributeProto_AttributeType node_attr_type = node_attr.second.type(); + switch (node_attr_type) { + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_UNDEFINED); + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_INT); + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_INTS); + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_FLOAT); + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_FLOATS); + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRING); + break; + } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRINGS); + break; + } + default: + // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. + ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); + } + attr_idx++; + } + } + // Check node subgraphs std::vector> node_subgraphs = node->GetSubgraphs(); diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc new file mode 100644 index 0000000000000..36b7f2965b483 --- /dev/null +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/ep_plugin_provider_interfaces.h" + +#include "gsl/gsl" +#include "gtest/gtest.h" + +#include "core/session/abi_devices.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime::test { + +// Helper class to access public ORT APIs. +struct ApiPtrs { + ApiPtrs() : ort_api{::OrtGetApiBase()->GetApi(ORT_API_VERSION)}, + ep_api{ort_api->GetEpApi()} { + } + + const gsl::not_null ort_api; + const gsl::not_null ep_api; +}; + +// Normally, a plugin EP would be implemented in a separate library. 
+// The `test_plugin_ep` namespace contains a local implementation intended for unit testing. +namespace test_plugin_ep { + +struct TestOrtEp : ::OrtEp, ApiPtrs { + TestOrtEp() : ::OrtEp{}, ApiPtrs{} { + ort_version_supported = ORT_API_VERSION; + + GetName = GetNameImpl; + + // Individual tests should fill out the other function pointers as needed. + } + + static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) { + constexpr const char* ep_name = "TestOrtEp"; + return ep_name; + } +}; + +// This factory doesn't do anything other than implement ReleaseEp(). +// It is only used to create the UniqueOrtEp that is required by PluginExecutionProvider. +struct TestOrtEpFactory : ::OrtEpFactory { + TestOrtEpFactory() : ::OrtEpFactory{} { + ort_version_supported = ORT_API_VERSION; + ReleaseEp = ReleaseEpImpl; + } + + static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) { + delete static_cast(ep); + } +}; + +static TestOrtEpFactory g_test_ort_ep_factory{}; + +struct MakeTestOrtEpResult { + std::unique_ptr ep; // the IExecutionProvider wrapping the TestOrtEp + gsl::not_null ort_ep; // the wrapped TestOrtEp, owned by `ep` +}; + +// Creates an IExecutionProvider that wraps a TestOrtEp. +// The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. +MakeTestOrtEpResult MakeTestOrtEp() { + auto ort_ep_raw = std::make_unique().release(); + auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); + auto ort_session_options = Ort::SessionOptions{}; + auto ort_ep_device = OrtEpDevice{}; + std::vector ep_devices{&ort_ep_device}; + + auto ep = std::make_unique(std::move(ort_ep), + *static_cast(ort_session_options), + g_test_ort_ep_factory, + ep_devices); + + auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; + return result; +} + +} // namespace test_plugin_ep + +TEST(PluginExecutionProviderTest, GetPreferredLayout) { + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + { + ort_ep->GetPreferredDataLayout = nullptr; + ASSERT_EQ(ep->GetPreferredLayout(), DataLayout::NCHW); + } + + { + auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + *preferred_data_layout = OrtEpDataLayout::OrtEpDataLayout_NCHW; + return nullptr; + }; + ort_ep->GetPreferredDataLayout = prefer_nhwc_fn; + ASSERT_EQ(ep->GetPreferredLayout(), DataLayout::NCHW); + } + +#if !defined(ORT_NO_EXCEPTIONS) + { + auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + *preferred_data_layout = static_cast(-1); + return nullptr; + }; + ort_ep->GetPreferredDataLayout = invalid_layout_fn; + ASSERT_THROW(ep->GetPreferredLayout(), OnnxRuntimeException); + } + + { + auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* { + auto* test_ort_ep = static_cast(this_ptr); + return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer."); + }; + ort_ep->GetPreferredDataLayout = failing_fn; + ASSERT_THROW(ep->GetPreferredLayout(), OnnxRuntimeException); + } +#endif // !defined(ORT_NO_EXCEPTIONS) +} + +TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + { + ort_ep->ShouldConvertDataLayoutForOp = nullptr; + ASSERT_EQ(ep->ShouldConvertDataLayoutForOp("", "Conv", DataLayout::NHWC), std::nullopt); + } + + { + auto custom_nhwc_op_determination_fn = [](OrtEp* /*this_ptr*/, + const char* /*node_domain*/, 
+ const char* node_op_type, + OrtEpDataLayout target_data_layout, + int* should_convert) -> ::OrtStatus* { + EXPECT_EQ(target_data_layout, OrtEpDataLayout::OrtEpDataLayout_NHWC); + + if (node_op_type == std::string_view{"Conv"}) { + *should_convert = 1; + } else if (node_op_type == std::string_view{"BatchNormalization"}) { + *should_convert = 0; + } else { + *should_convert = -1; + } + return nullptr; + }; + ort_ep->ShouldConvertDataLayoutForOp = custom_nhwc_op_determination_fn; + + std::optional should_convert{}; + + should_convert = ep->ShouldConvertDataLayoutForOp("", "Conv", DataLayout::NHWC); + ASSERT_NE(should_convert, std::nullopt); + ASSERT_EQ(*should_convert, true); + + should_convert = ep->ShouldConvertDataLayoutForOp("", "BatchNormalization", DataLayout::NHWC); + ASSERT_NE(should_convert, std::nullopt); + ASSERT_EQ(*should_convert, false); + + should_convert = ep->ShouldConvertDataLayoutForOp("", "GridSample", DataLayout::NHWC); + ASSERT_EQ(should_convert, std::nullopt); + } + +#if !defined(ORT_NO_EXCEPTIONS) + { + auto failing_fn = [](OrtEp* this_ptr, + const char* /*node_domain*/, + const char* /*node_op_type*/, + OrtEpDataLayout /*target_data_layout*/, + int* /*should_convert*/) -> ::OrtStatus* { + auto* test_ort_ep = static_cast(this_ptr); + return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, + "To convert to NHWC or not to convert to NHWC..."); + }; + ort_ep->ShouldConvertDataLayoutForOp = failing_fn; + ASSERT_THROW(ep->ShouldConvertDataLayoutForOp("", "Conv", DataLayout::NHWC), OnnxRuntimeException); + } +#endif // !defined(ORT_NO_EXCEPTIONS) +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 1e6167f862ea1..2ce3c4859394d 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -44,6 +44,7 @@ #include "core/providers/rocm/rocm_provider_factory.h" #include "core/providers/rocm/gpu_data_transfer.h" #endif +#include "core/session/allocator_adapters.h" #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" @@ -2752,30 +2753,40 @@ TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsUseSameOrtCreatedAllo [mem_info](int) { return std::make_unique(mem_info); }, 0, use_arena}; - AllocatorPtr allocator_ptr = CreateAllocator(device_info); - st = env->RegisterAllocator(allocator_ptr); + // convert to OrtAllocator* to use the public method used to register allocators by the ORT API + OrtAllocatorImplWrappingIAllocator ort_allocator(CreateAllocator(device_info)); + st = env->RegisterAllocator(&ort_allocator); ASSERT_STATUS_OK(st); - // create sessions to share the allocator - SessionOptions so1; - ASSERT_STATUS_OK(so1.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1")); - InferenceSessionTestSharingAllocator sess1(so1, *env); - ASSERT_STATUS_OK(sess1.Load(MODEL_URI)); - ASSERT_STATUS_OK(sess1.Initialize()); + { + // create sessions to share the allocator + SessionOptions so1; + ASSERT_STATUS_OK(so1.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1")); + InferenceSessionTestSharingAllocator sess1(so1, *env); + ASSERT_STATUS_OK(sess1.Load(MODEL_URI)); + ASSERT_STATUS_OK(sess1.Initialize()); - SessionOptions so2; - ASSERT_STATUS_OK(so2.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1")); - InferenceSessionTestSharingAllocator sess2(so2, *env); 
- ASSERT_STATUS_OK(sess2.Load(MODEL_URI)); - ASSERT_STATUS_OK(sess2.Initialize()); + SessionOptions so2; + ASSERT_STATUS_OK(so2.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1")); + InferenceSessionTestSharingAllocator sess2(so2, *env); + ASSERT_STATUS_OK(sess2.Load(MODEL_URI)); + ASSERT_STATUS_OK(sess2.Initialize()); - // This line ensures the allocator in the session is the same as that in the env - ASSERT_EQ(sess1.GetSessionState().GetAllocator(mem_info).get(), - allocator_ptr.get()); + // Need to undo the wrapping that happens in Environment::RegisterAllocator to be able to compare the pointers + const OrtAllocator* session_allocator = reinterpret_cast( + sess1.GetSessionState().GetAllocator(mem_info).get()) + ->GetWrappedOrtAllocator(); - // This line ensures the underlying IAllocator* is the same across 2 sessions. - ASSERT_EQ(sess1.GetSessionState().GetAllocator(mem_info).get(), - sess2.GetSessionState().GetAllocator(mem_info).get()); + // This line ensures the allocator in the session is the same as that in the env + ASSERT_EQ(session_allocator, &ort_allocator); + + // This line ensures the underlying IAllocator* is the same across 2 sessions. + ASSERT_EQ(sess1.GetSessionState().GetAllocator(mem_info).get(), + sess2.GetSessionState().GetAllocator(mem_info).get()); + } + + // registered as the allocator will become invalid before the environment is destroyed + ASSERT_STATUS_OK(env->UnregisterAllocator(mem_info)); } // Ensure sessions don't use the same allocator. It uses ORT created allocator. @@ -2800,8 +2811,8 @@ TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsDontUseSameOrtCreated [mem_info](int) { return std::make_unique(mem_info); }, 0, use_arena}; - AllocatorPtr allocator_ptr = CreateAllocator(device_info); - st = env->RegisterAllocator(allocator_ptr); + OrtAllocatorImplWrappingIAllocator ort_allocator(CreateAllocator(device_info)); + st = env->RegisterAllocator(&ort_allocator); ASSERT_STATUS_OK(st); // create sessions to share the allocator @@ -2817,9 +2828,13 @@ TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsDontUseSameOrtCreated ASSERT_STATUS_OK(sess2.Load(MODEL_URI)); ASSERT_STATUS_OK(sess2.Initialize()); + // Need to undo the wrapping that happens in Environment::RegisterAllocator to be able to compare the pointers + const OrtAllocator* session_allocator = reinterpret_cast( + sess1.GetSessionState().GetAllocator(mem_info).get()) + ->GetWrappedOrtAllocator(); + // This line ensures the allocator in the session is the same as that in the env - ASSERT_EQ(sess1.GetSessionState().GetAllocator(mem_info).get(), - allocator_ptr.get()); + ASSERT_EQ(session_allocator, &ort_allocator); // This line ensures the underlying OrtAllocator* is the same across 2 sessions. 
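The allocator-sharing updates above all follow the same pattern: the test wraps its `IAllocator` in an `OrtAllocatorImplWrappingIAllocator` so it can go through the public `OrtAllocator*` registration path, and `Environment::RegisterAllocator` wraps that pointer again before sessions see it, which is why the assertions unwrap the session-side allocator with `GetWrappedOrtAllocator()` before comparing addresses. A condensed, test-style sketch of the lifecycle, assuming the same helpers and macros as the surrounding test file:

```cpp
// Sketch of the registration lifecycle used in the updated tests (not production code).
OrtAllocatorImplWrappingIAllocator ort_allocator(CreateAllocator(device_info));

// Register via the public OrtAllocator* path; the environment wraps this pointer
// in its own IAllocator adapter before handing it to sessions.
ASSERT_STATUS_OK(env->RegisterAllocator(&ort_allocator));

// ... create sessions with kOrtSessionOptionsConfigUseEnvAllocators set to "1",
// unwrap the session-side allocator with GetWrappedOrtAllocator(), and compare
// it against &ort_allocator, as in the tests above ...

// The local wrapper must outlive its registration, so unregister it (by memory
// info) before it and the environment are destroyed.
ASSERT_STATUS_OK(env->UnregisterAllocator(mem_info));
```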
ASSERT_NE(sess1.GetSessionState().GetAllocator(mem_info).get(), diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 16555eafcd897..a74ecacc1f26e 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1393,59 +1393,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider } if (provider_name == "qnn") { - broken_tests->insert({"gemm_default_no_bias", "result differs"}); broken_tests->insert({"resize_downsample_scales_linear", "result differs"}); - broken_tests->insert({"resize_downsample_scales_linear_antialias", "result differs"}); - broken_tests->insert({"resize_downsample_sizes_linear_antialias", "result differs"}); - broken_tests->insert({"sce_NCd1_mean_weight_negative_ii", "result differs"}); - broken_tests->insert({"sce_NCd1_mean_weight_negative_ii_expanded", "result differs"}); - broken_tests->insert({"sce_NCd1_mean_weight_negative_ii_log_prob", "result differs"}); - broken_tests->insert({"sce_NCd1_mean_weight_negative_ii_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean", "result differs"}); - broken_tests->insert({"sce_mean_3d", "result differs"}); - broken_tests->insert({"sce_mean_3d_expanded", "result differs"}); - broken_tests->insert({"sce_mean_3d_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_3d_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_expanded", "result differs"}); - broken_tests->insert({"sce_mean_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_3d", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_3d_expanded", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_3d_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_3d_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_4d", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_4d_expanded", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_4d_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_4d_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_expanded", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_no_weight_ii_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight", "result differs"}); - broken_tests->insert({"sce_mean_weight_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_3d", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_3d_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_3d_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_3d_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_4d", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_4d_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_4d_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_4d_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight_ii_log_prob", "result differs"}); - 
broken_tests->insert({"sce_mean_weight_ii_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_mean_weight_log_prob", "result differs"}); - broken_tests->insert({"sce_mean_weight_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_none", "result differs"}); - broken_tests->insert({"sce_none_expanded", "result differs"}); - broken_tests->insert({"sce_none_log_prob", "result differs"}); - broken_tests->insert({"sce_none_log_prob_expanded", "result differs"}); - broken_tests->insert({"sce_sum", "result differs"}); - broken_tests->insert({"sce_sum_expanded", "result differs"}); - broken_tests->insert({"sce_sum_log_prob", "result differs"}); - broken_tests->insert({"sce_sum_log_prob_expanded", "result differs"}); - broken_tests->insert({"gridsample_reflection_padding", "result differs"}); broken_tests->insert({"gridsample_volumetric_nearest_align_corners_0", "unknown version"}); broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"}); broken_tests->insert({"rotary_embedding", "unknown version"}); @@ -1454,9 +1402,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"rotary_embedding_no_position_ids_expanded", "unknown version"}); broken_tests->insert({"rotary_embedding_no_position_ids_interleaved", "unknown version"}); broken_tests->insert({"rotary_embedding_no_position_ids_interleaved_expanded", "unknown version"}); - broken_tests->insert({"spacetodepth", "result differs"}); - broken_tests->insert({"reduce_sum_square_empty_set_expanded", "unknown version"}); - // Fails with QNN SDK 2.17.0: + // Fails since QNN SDK 2.17.0: // expected 7.70947 (40f6b3f3), got 7.84096 (40fae920), diff: 0.131491, tol=0.00870947 idx=419. 100 of 1715 differ broken_tests->insert({"facedetection_op8_qdq", "result differs"}); // Fails with QNN SDK 2.34.0: @@ -1466,11 +1412,6 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"mobilenetv2-1.0", "result differs with 2.34"}); broken_tests->insert({"facedetection_op8", "segfault with CPU backend, will be fixed by QNN 2.36"}); -#if defined(_WIN32) && defined(_M_AMD64) - // Fails with QNN SDK 2.17.0 on Windows x64: - // expected 13.5 (41580000), got 0 (0), diff: 13.5, tol=0.0145 idx=3. 3 of 4 differ - broken_tests->insert({"averagepool_2d_ceil", "result differs"}); -#endif // These next 3 Resize tests fail on CPU backend with QNN SDK 2.22.0 due to inaccuracy. // output=Y:expected 1 (3f800000), got 3 (40400000), diff: 2, tol=0.002 idx=24. 8 of 56 differ broken_tests->insert({"resize_upsample_sizes_nearest", "result differs"}); @@ -1482,12 +1423,6 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"convtranspose_group_2_image_3", "Segmentation fault (core dumped). 
CPU test passed."}); // Fails with QNN 2.31 on Windows x64 for CPU broken_tests->insert({"gelu_tanh_2", "y:expected -0.0131778 (bc57e7d5), got -0.0136333 (bc5f5e38), diff: 0.000455472, tol=2.31778e-05."}); - broken_tests->insert({"convtranspose_pad", "Access violation 0xc000005 from call graphAddNode."}); - broken_tests->insert({"convtranspose_pads", "Access violation 0xc000005 from call graphAddNode."}); - broken_tests->insert({"convtranspose_output_shape", "Access violation 0xc000005 from call graphAddNode."}); - broken_tests->insert({"convtranspose_kernel_shape", "Access violation 0xc000005 from call graphAddNode."}); - broken_tests->insert({"convtranspose_1d", "Access violation 0xc000005 from call graphAddNode."}); - broken_tests->insert({"convtranspose", "Access violation 0xc000005 from call graphAddNode."}); broken_tests->insert({"averagepool_2d_ceil", "result differs. expected 13.5 (41580000), got 0 (0)"}); // Fails with QNN 2.32 broken_tests->insert({"resize_upsample_scales_linear", "expected 1 (3f800000), got 0.25 (3e800000)"}); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a6a5004a2e2e2..099b8b23dc93d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2696,6 +2696,43 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_NeedReshape_3D) { 1, pre_graph_checker, post_graph_checker)); } +// With attention pattern, but targeting an execution provider that does not perform +// AttentionFusion, fuse into GEMM should still be happen, rather than skipping them +TEST_F(GraphTransformationTests, MatMulAddFusion_PreserveAttentionPattern) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/matmul_add_from_attention.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + + // This toy model contains 11 MatMul + Add pairs, 0 GEMMs. 
+ // 7 of them are out of "Attention Pattern" (see MatMulAddFusion::IsAttentionPattern) + // 4 of them are in "Attention Pattern" conditionally skipped by MatMulAddFusion pass + OpCountMap op_count_before = CountOpsInGraph(p_model->MainGraph()); + const InlinedHashSet empty_list = {}; + + // In attention pattern, 4 MatMul + Add pairs should be preserved + Graph& graph = p_model->MainGraph(); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(empty_list, /*preserve_attention_pattern=*/true), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + OpCountMap op_count_cpu_ep = CountOpsInGraph(graph); + constexpr int expected_fusions1 = 11 - 4; + ASSERT_EQ(op_count_cpu_ep["MatMul"], op_count_before["MatMul"] - expected_fusions1); + ASSERT_EQ(op_count_cpu_ep["Add"], op_count_before["Add"] - expected_fusions1); + ASSERT_EQ(op_count_cpu_ep["Gemm"], op_count_before["Gemm"] + expected_fusions1); + + // In attention pattern, 4 MatMul + Add pairs should be fused into Gemm + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(empty_list, /*preserve_attention_pattern=*/false), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + OpCountMap op_count_qnn_ep = CountOpsInGraph(graph); + constexpr int expected_fusions2 = 11; + ASSERT_EQ(op_count_qnn_ep["MatMul"], op_count_before["MatMul"] - expected_fusions2); + ASSERT_EQ(op_count_qnn_ep["Add"], op_count_before["Add"] - expected_fusions2); + ASSERT_EQ(op_count_qnn_ep["Gemm"], op_count_before["Gemm"] + expected_fusions2); +} + #ifndef DISABLE_CONTRIB_OPS TEST_F(GraphTransformationTests, Gemm_Relu_three_input) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx"; diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index dcbb953a2e05a..5b2865a3feed7 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -539,9 +539,6 @@ TEST(SliceTest, Slice1D_ReverseAllAxes_1) { if (DefaultVSINPUExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{4}] did not match run output shape [{0}] for output"; } - if (DefaultWebGpuExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Not covered by WebGPU test suite"; - } RunSliceTest({4}, {1.0f, 2.0f, 3.0f, 4.0f}, @@ -556,9 +553,6 @@ TEST(SliceTest, Slice1D_ReverseAllAxes_1) { // With numeric_limit_min, the end value should be clamped to -1 TEST(SliceTest, Slice1D_ReverseAllAxes_2) { - if (DefaultWebGpuExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Not covered by WebGPU test suite"; - } RunSliceTest({4}, {1.0f, 2.0f, 3.0f, 4.0f}, {-1}, @@ -572,9 +566,6 @@ TEST(SliceTest, Slice1D_ReverseAllAxes_2) { // giving an end value < -{dim_value} should also clamp it to -1 TEST(SliceTest, Slice1D_ReverseAllAxes_3) { - if (DefaultWebGpuExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Not covered by WebGPU test suite"; - } RunSliceTest({4}, {1.0f, 2.0f, 3.0f, 4.0f}, {-1}, @@ -591,9 +582,6 @@ TEST(SliceTest, Slice2D_ReverseAllAxes) { if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{4}] did not match run output 
shape [{0}] for output"; } - if (DefaultWebGpuExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Not covered by WebGPU test suite"; - } RunSliceTest({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}, @@ -611,9 +599,6 @@ TEST(SliceTest, Slice2D_ReverseSubsetOfAxes_1) { if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; } - if (DefaultWebGpuExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Not covered by WebGPU test suite"; - } RunSliceTest({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}, @@ -631,9 +616,6 @@ TEST(SliceTest, Slice2D_ReverseSubsetOfAxes_2) { if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{2,2}] did not match run output shape [{0,2}] for output"; } - if (DefaultWebGpuExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Not covered by WebGPU test suite"; - } RunSliceTest({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}, @@ -688,9 +670,6 @@ TEST(SliceTest, Slice2D_ReverseSubsetOfNegAxes_1) { if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{2,2}] did not match run output shape [{2,0}] for output"; } - if (DefaultWebGpuExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Not covered by WebGPU test suite"; - } RunSliceTest({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 231fedbea31da..8bf3955cce433 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -342,8 +342,6 @@ template void run_blkq4_gemm<32, false, true, false>(int m, int n, int k); template void run_blkq4_gemm<64, false, true, true>(int m, int n, int k); template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); - - /// @brief Testing small tile GEMM impl template < int block_size, @@ -363,7 +361,7 @@ void run_blkq4_small_gemm(int m, int n, int k) { true>; using QuantBlocking = cutlass::MatrixShape; using LayoutQmeta = typename std::conditional::value, - cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::type; + cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::type; using WarpShape = cutlass::gemm::GemmShape<16, 16, 64>; // change split k to 1 to help debug in case of test failure @@ -383,7 +381,7 @@ void run_blkq4_small_gemm(int m, int n, int k) { const auto meta_shape = cutlass::make_Coord(problem_size.k() / QuantBlocking::kRow, problem_size.n() / QuantBlocking::kColumn); if ((problem_size.k() % QuantBlocking::kRow != 0) || - (problem_size.n() % QuantBlocking::kColumn) != 0){ + (problem_size.n() % QuantBlocking::kColumn) != 0) { ORT_THROW("Test case setup fail: partial quantization block not supported!"); } @@ -422,7 +420,7 @@ void run_blkq4_small_gemm(int m, int n, int k) { cutlass::half_t(1.25), cutlass::half_t(-1.0), 5); // <- Fill matrix A on host with uniform-distribution random data -// std::cout << "========== A: ============ \n" << tensor_a.host_view() << std::endl; + // std::cout << "========== A: ============ \n" << tensor_a.host_view() << std::endl; cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), 1, @@ -531,7 +529,6 @@ template void run_blkq4_small_gemm<64, false, false>(int m, int n, int k); template void 
run_blkq4_small_gemm<128, false, true>(int m, int n, int k); template void run_blkq4_small_gemm<128, false, false>(int m, int n, int k); - } // namespace test } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index ee0aff6d26444..390329c5cae7a 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -242,12 +242,16 @@ common::Status InternalTestingExecutionProvider::Compile(const std::vector(mem_info); OrtEnv& env = *(OrtEnv*)(*ort_env); + OrtAllocatorImplWrappingIAllocator ort_allocator(std::move(replacement_alloc)); - ASSERT_STATUS_OK(env.RegisterAllocator(replacement_alloc)); + ASSERT_STATUS_OK(env.GetEnvironment().RegisterAllocator(&ort_allocator)); SessionOptions so; ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1")); @@ -284,7 +287,13 @@ TEST(InternalTestingEP, TestReplaceAllocatorDoesntBreakDueToLocalAllocatorStorag ASSERT_STATUS_OK(session.Load(ort_model_path)); ASSERT_STATUS_OK(session.Initialize()); - ASSERT_EQ(replacement_alloc, session.GetAllocator(OrtMemoryInfo())) << "Allocators registered from Env should have the highest priority"; + // Need to undo the wrapping that happens in Environment::RegisterAllocator to be able to compare the pointers + const OrtAllocator* session_allocator = reinterpret_cast( + session.GetAllocator(mem_info).get()) + ->GetWrappedOrtAllocator(); + + ASSERT_EQ(session_allocator, &ort_allocator) + << "Allocators registered from Env should have the highest priority"; } #endif // !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index a96c8d05ee64f..22a80da2f95d2 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -374,7 +374,7 @@ TEST(NnapiExecutionProviderTest, DISABLED_TestQDQMul) { }); } -TEST(NnapiExecutionProviderTest, TestQDQTranspose) { +TEST(NnapiExecutionProviderTest, DISABLED_TestQDQTranspose) { RunQDQModelTest(BuildQDQTransposeTestCase( {1, 3, 32, 32} /* input_shape */, diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index f5f68a20d327c..4718a38ce4e1c 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -35,9 +35,9 @@ class NvExecutionProviderTest : public ::testing::Test { } else if constexpr (std::is_same::value) { dtype_name = "fp32"; } else if constexpr (std::is_same::value) { - dtype_name = "fp16"; - } else if constexpr (std::is_same::value) { dtype_name = "bf16"; + } else if constexpr (std::is_same::value) { + dtype_name = "fp16"; } else if constexpr (std::is_same::value) { dtype_name = "int8"; } else if constexpr (std::is_same::value) { @@ -184,6 +184,7 @@ static void CreateBaseModel(const PathString& model_name, auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); status = onnxruntime::Model::Save(model, model_name); + ASSERT_TRUE(status.IsOK()); } static Ort::IoBinding generate_io_binding(Ort::Session& session, std::map> shape_overwrites = {}) { diff --git a/onnxruntime/test/providers/qnn/average_pool_test.cc 
b/onnxruntime/test/providers/qnn/average_pool_test.cc index 7969f4472629a..590694c6fa740 100644 --- a/onnxruntime/test/providers/qnn/average_pool_test.cc +++ b/onnxruntime/test/providers/qnn/average_pool_test.cc @@ -106,6 +106,23 @@ TEST_F(QnnCPUBackendTests, AveragePool_AutopadSameLower) { ExpectedEPNodeAssignment::All); } +// AveragePool 3D as GlobalAveragePool. +TEST_F(QnnCPUBackendTests, AveragePool_3D_AsGlobal) { + RunAveragePoolOpTest("AveragePool", + {TestInputDef({1, 2, 3, 3, 3}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + utils::MakeAttribute("strides", std::vector{3, 3, 3})}, + ExpectedEPNodeAssignment::All); +} + +// GlobalAveragePool 3D. +TEST_F(QnnCPUBackendTests, GlobalAveragePool_3D) { + RunAveragePoolOpTest("GlobalAveragePool", + {TestInputDef({1, 2, 3, 3, 3}, false, -10.0f, 10.0f)}, + {}, + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: @@ -142,9 +159,7 @@ TEST_F(QnnHTPBackendTests, AveragePool_CountIncludePad_HTP_u8) { {utils::MakeAttribute("kernel_shape", std::vector{1, 1}), utils::MakeAttribute("count_include_pad", static_cast(1))}, ExpectedEPNodeAssignment::All, - 18, - // Need tolerance of 0.414% of output range after QNN SDK 2.17 - QDQTolerance(0.00414f)); + 18); } // QDQ AveragePool that use auto_pad 'SAME_UPPER'. @@ -157,9 +172,7 @@ TEST_F(QnnHTPBackendTests, AveragePool_AutopadSameUpper_HTP_u8) { {utils::MakeAttribute("kernel_shape", std::vector{1, 1}), utils::MakeAttribute("auto_pad", "SAME_UPPER")}, ExpectedEPNodeAssignment::All, - 18, - // Need to use tolerance of 0.414% of output range after QNN SDK 2.17 - QDQTolerance(0.00414f)); + 18); } // QDQ AveragePool that use auto_pad 'SAME_LOWER'. @@ -172,9 +185,34 @@ TEST_F(QnnHTPBackendTests, AveragePool_AutopadSameLower_HTP_u8) { {utils::MakeAttribute("kernel_shape", std::vector{1, 1}), utils::MakeAttribute("auto_pad", "SAME_LOWER")}, ExpectedEPNodeAssignment::All, - 18, - // Need to use tolerance of 0.414% of output range after QNN SDK 2.17 - QDQTolerance(0.00414f)); + 18); +} + +// QDQ AveragePool 3D. +TEST_F(QnnHTPBackendTests, AveragePool_3D_u8) { + RunQDQAveragePoolOpTest("AveragePool", + {TestInputDef({1, 2, 8, 8, 8}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + utils::MakeAttribute("strides", std::vector{2, 2, 2})}, + ExpectedEPNodeAssignment::All); +} + +// QDQ AveragePool 3D with auto_pad SAME_UPPER. +TEST_F(QnnHTPBackendTests, AveragePool_3D_AutoPad_SAME_UPPER_u8) { + RunQDQAveragePoolOpTest("AveragePool", + {TestInputDef({1, 2, 8, 8, 8}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("kernel_shape", std::vector{2, 2, 2}), + utils::MakeAttribute("auto_pad", "SAME_UPPER")}, + ExpectedEPNodeAssignment::All); +} + +// QDQ AveragePool 3D with auto_pad SAME_LOWER. 
+TEST_F(QnnHTPBackendTests, AveragePool_3D_AutoPad_SAME_LOWER_u8) { + RunQDQAveragePoolOpTest("AveragePool", + {TestInputDef({1, 2, 8, 8, 8}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("kernel_shape", std::vector{2, 2, 2}), + utils::MakeAttribute("auto_pad", "SAME_LOWER")}, + ExpectedEPNodeAssignment::All); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 15ac5d3cd6369..ab716382d3a10 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -708,9 +708,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_Test_QDQConvWithDynamicWeightsFromMul) { RunQnnModelTest(BuildConvMulGraph, provider_options, 13, - ExpectedEPNodeAssignment::All, - 4e-4f); // Accuracy decreased slightly in QNN SDK 2.17. - // Expected: 9.94500065, Actual: 9.94537735 + ExpectedEPNodeAssignment::All); } // Check that QNN compiles DQ -> Conv -> Q as a single unit. @@ -727,9 +725,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_dynamic_input) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 13, // opset - // Need tolerance of 0.413% of output range after QNN SDK 2.17 - QDQTolerance(0.00413f)); + 13); // opset RunHTPConvOpTest("Conv", TestInputDef({1, 1, 5, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input @@ -742,9 +738,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_dynamic_input) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 13, // opset - // Need tolerance of 0.413% of output range after QNN SDK 2.17 - QDQTolerance(0.00413f)); + 13); // opset } // Test per-channel QDQ Conv. in0: u8, in1 (weight): s8, in2 (bias): s32, out: u8 @@ -1911,9 +1905,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_initializer) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 13, // opset - // Need tolerance of 0.413% of output range after QNN SDK 2.17 - QDQTolerance(0.00413f)); + 13); // opset RunHTPConvOpTest("Conv", TestInputDef({1, 1, 5, 5, 5}, false, 0.0f, 10.0f), // Random dynamic input @@ -1926,9 +1918,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_bias_initializer) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 13, // opset - // Need tolerance of 0.413% of output range after QNN SDK 2.17 - QDQTolerance(0.00413f)); + 13); // opset } // Tests 1D Conv with bias as an initializer. @@ -2136,12 +2126,6 @@ TEST_F(QnnHTPBackendTests, DISABLED_ConvU8U8S32_large_input1_padding_bias_initia } TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input2_bias_initializer) { -#ifdef __linux__ - // On Linux QNN SDK 2.17: Need a tolerance of 0.785% of output range to pass. 
- QDQTolerance tolerance = QDQTolerance(0.00785f); -#else - QDQTolerance tolerance = QDQTolerance(); -#endif RunHTPConvOpTest("Conv", TestInputDef({1, 128, 8, 56}, false, 0.f, 10.f), // Dynamic input TestInputDef({32, 128, 1, 1}, true, -1.f, 1.f), // Random static weights @@ -2153,8 +2137,7 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input2_bias_initializer) { "NOTSET", ExpectedEPNodeAssignment::All, false, - 13, - tolerance); + 13); } TEST_F(QnnHTPBackendTests, ConvU8U8S32_LargeInput_Dilations_Pads) { diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index 73665ec1b9bdc..ef19b37c1eb30 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -336,8 +336,7 @@ TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicInputs) { ExpectedEPNodeAssignment::All, 13, false, - // Require tolerance of 0.74% on Windows ARM64. - QDQTolerance(0.0074f)); + QDQTolerance(0.00410f)); } TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_DynamicC) { @@ -356,8 +355,7 @@ TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_DynamicC) { ExpectedEPNodeAssignment::All, 13, false, - // Require tolerance of 0.74% on Windows ARM64. - QDQTolerance(0.0074f)); + QDQTolerance(0.00410f)); } TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { @@ -376,8 +374,7 @@ TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { ExpectedEPNodeAssignment::All, 13, false, - // Require tolerance of 0.74% on Windows ARM64. - QDQTolerance(0.0074f)); + QDQTolerance(0.00410f)); } // Test 16-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 182877ddf200c..7aa3f030d9f43 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -32,13 +32,7 @@ static void RunLayerNormCpuTest(const TestInputDef& input_def, expected_ep_assignment); } -#ifdef __linux__ -// This CPU test fails on Linux, QNN SDK 2.17 -// the value pair (-1.75661933, 0) at index #1 don't match, which is 1.75662 from -1.75662 -TEST_F(QnnCPUBackendTests, DISABLED_LayerNorm) { -#else TEST_F(QnnCPUBackendTests, LayerNorm) { -#endif RunLayerNormCpuTest(TestInputDef({2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), {utils::MakeAttribute("axis", static_cast(0))}, @@ -210,7 +204,7 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { // Test accuracy of 8-bit QDQ LayerNorm with a dynamic scale input. // -// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. Still fails on QNN SDK 2.28.2. +// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. Still fails on QNN SDK 2.35.0. 
// Verbose logs: // Starting stage: Graph Transformations and Optimizations // C:\...\QNN\HTP\HTP\src\hexagon\prepare\graph_prepare.cc:203:ERROR:could not create op: q::flat_to_vtcm diff --git a/onnxruntime/test/providers/qnn/lrn_op_test.cc b/onnxruntime/test/providers/qnn/lrn_op_test.cc index bb3a40a47a750..35ec2cb450691 100644 --- a/onnxruntime/test/providers/qnn/lrn_op_test.cc +++ b/onnxruntime/test/providers/qnn/lrn_op_test.cc @@ -149,20 +149,13 @@ TEST_F(QnnHTPBackendTests, LRNSize5) { } TEST_F(QnnHTPBackendTests, LRN_size_larger_than_channel) { -#ifdef __linux__ - // On Linux QNN SDK 2.17: Need a tolerance of 0.407% of output range to pass. - QDQTolerance tolerance = QDQTolerance(0.00407f); -#else - QDQTolerance tolerance = QDQTolerance(); -#endif RunQDQLRNOpTest(TestInputDef({1, 128, 4, 5}, false, -10.0f, 10.0f), 255, // Size ExpectedEPNodeAssignment::All, 0.0001f, // alpha 0.75f, // beta 1.0f, // bias - 13, // opset - tolerance); + 13); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index e983c31ab1b75..723717351ea86 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -194,13 +194,7 @@ TEST_F(QnnCPUBackendTests, MatMulOp) { RunMatMulOpTest(false, {2, 3, 3, 3}, {3, 2}, false, true); RunMatMulOpTest(false, {2, 3, 3, 3}, {2, 3, 3, 2}, false, true); -#if defined(__linux__) - // TODO: This fails on Linux (HTP emulation). Works on Windows ARM64. - // Expected: contains 24 values, where each value and its corresponding value in 16-byte object <18-00 00-00 00-00 00-00 00-29 4E-53 A8-55 00-00> are an almost-equal pair - // Actual: 16-byte object <18-00 00-00 00-00 00-00 80-28 3E-53 A8-55 00-00>, where the value pair (0.0285999943, 0) at index #12 don't match, which is -0.0286 from 0.0286 -#else RunMatMulOpTest(false, {2, 1, 2, 3}, {3, 3, 2}, false, false); -#endif RunMatMulOpTest(false, {3}, {3}, false, false); RunMatMulOpTest(false, {3}, {3}, false, true); RunMatMulOpTest(false, {3}, {3}, true, false); @@ -285,7 +279,7 @@ TEST_F(QnnHTPBackendTests, MatMulOp_QDQ) { // UINT16, per-channel INT8 weight RunQDQPerChannelMatMulOpTest({2, 3}, {3, 2}, 1, QDQTolerance(), ExpectedEPNodeAssignment::All, 21, false, false); - RunQDQPerChannelMatMulOpTest({2, 3, 3}, {3}, -1, QDQTolerance(0.005f)); + RunQDQPerChannelMatMulOpTest({2, 3, 3}, {3}, -1, QDQTolerance(0.0041f)); } // Tests MatMul with two uint16 (quantized) inputs that are both dynamic. 
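A recurring theme in the QNN test changes in this section is replacing ad-hoc per-test tolerances (0.414%, 0.539%, 0.74%, ...) with tighter values such as `QDQTolerance(0.0041f)` or the default. As the removed comments state, these tolerances are fractions of the output range, not absolute errors. A small illustration of the conversion, using a nominal range of [-10, 10] purely as an example (not a measured output range):

```cpp
#include <iostream>

// Illustration only: turning a relative QDQTolerance into an absolute bound,
// assuming it is applied as "fraction of output range" as the removed
// "x% of output range" comments above describe.
int main() {
  const float output_min = -10.0f;
  const float output_max = 10.0f;
  const float output_range = output_max - output_min;  // 20

  const float new_tolerance = 0.0041f;  // 0.41% of range -> 0.082 absolute
  const float old_tolerance = 0.0074f;  // 0.74% of range -> 0.148 absolute

  std::cout << "new absolute tolerance: " << new_tolerance * output_range << "\n"
            << "old absolute tolerance: " << old_tolerance * output_range << "\n";
  return 0;
}
```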
diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index 9284df6f8a4a8..d51eeeea1aea8 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -5,15 +5,16 @@ #include #include +#include +#include + +#include "gtest/gtest.h" #include "core/graph/node_attr_utils.h" +#include "core/graph/onnx_protobuf.h" #include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" -#include "core/graph/onnx_protobuf.h" - -#include "gtest/gtest.h" - namespace onnxruntime { namespace test { @@ -145,6 +146,18 @@ TEST_F(QnnCPUBackendTests, DISABLED_MaxPool_Large_Input2_Ceil) { ExpectedEPNodeAssignment::All); } +TEST_F(QnnCPUBackendTests, MaxPool_3D) { + RunPoolOpTest("MaxPool", + TestInputDef({1, 2, 3, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + utils::MakeAttribute("strides", std::vector{3, 3, 3}), + utils::MakeAttribute("pads", std::vector{0, 0, 0, 0, 0, 0}), + utils::MakeAttribute("dilations", std::vector{1, 1, 1}), + utils::MakeAttribute("ceil_mode", static_cast(0)), + utils::MakeAttribute("auto_pad", "NOTSET")}, + ExpectedEPNodeAssignment::All); +} + // GlobalMaxPool test TEST_F(QnnCPUBackendTests, GlobalMaxPoolTest) { RunPoolOpTest("GlobalMaxPool", @@ -153,6 +166,13 @@ TEST_F(QnnCPUBackendTests, GlobalMaxPoolTest) { ExpectedEPNodeAssignment::All); } +TEST_F(QnnCPUBackendTests, GlobalMaxPool_3D) { + RunPoolOpTest("GlobalMaxPool", + TestInputDef({1, 2, 3, 3, 3}, false, -10.0f, 10.0f), + {}, + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: @@ -182,10 +202,8 @@ TEST_F(QnnHTPBackendTests, MaxPool_Large_Input_HTP_u8) { utils::MakeAttribute("storage_order", static_cast(0)), utils::MakeAttribute("auto_pad", "NOTSET")}, ExpectedEPNodeAssignment::All, - 18, // opset - false, // use_contrib_qdq_ops - // Need a tolerance of 0.417% of output range after QNN SDK 2.17 - QDQTolerance(0.00417f)); + 18, // opset + false); // use_contrib_qdq_ops } TEST_F(QnnHTPBackendTests, MaxPool1D_ReshapeNodesPresent) { @@ -386,6 +404,30 @@ TEST_F(QnnHTPBackendTests, MaxPool_LargeInput_1Pads_u16) { true); // use_contrib_qdq_ops } +// Test uint8 QDQ MaxPool with auto_pad SAME_LOWER. +TEST_F(QnnHTPBackendTests, MaxPool_AutoPad_SAME_LOWER_u8) { + RunQDQPoolOpTest("MaxPool", + TestInputDef({1, 3, 16, 24}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{2, 2}), + utils::MakeAttribute("strides", std::vector{2, 2}), + utils::MakeAttribute("auto_pad", "SAME_LOWER")}, + ExpectedEPNodeAssignment::All, + 18, + true); +} + +// Test uint8 QDQ MaxPool with auto_pad SAME_UPPER. 
+TEST_F(QnnHTPBackendTests, MaxPool_AutoPad_SAME_UPPER_u8) { + RunQDQPoolOpTest("MaxPool", + TestInputDef({1, 3, 16, 24}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{2, 2}), + utils::MakeAttribute("strides", std::vector{2, 2}), + utils::MakeAttribute("auto_pad", "SAME_UPPER")}, + ExpectedEPNodeAssignment::All, + 18, + true); +} + // QDQ GlobalMaxPool test TEST_F(QnnHTPBackendTests, GlobalMaxPool_u8) { std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 18); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 87385b7964d98..4febfe7ba836d 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -748,7 +748,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) { QnnContextBinaryMultiPartitionTestBody(single_ep_node); } -void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { +void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file, bool load_model_from_buffer = false) { ProviderOptions provider_options; provider_options["backend_type"] = "htp"; @@ -787,7 +787,22 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { so.AddConfigEntry(kOrtSessionOptionsEpContextModelExternalInitializersFileName, external_ini_file.c_str()); } // otherwise all initializers are in Onnx file, no external data file generated - Ort::Session session(*ort_env, ToPathString(model_with_ext).c_str(), so); + if (load_model_from_buffer) { + std::vector buffer; + { + std::ifstream file(model_with_ext, std::ios::binary | std::ios::ate); + if (!file) + ORT_THROW("Error reading model"); + buffer.resize(narrow(file.tellg())); + file.seekg(0, std::ios::beg); + if (!file.read(buffer.data(), buffer.size())) + ORT_THROW("Error reading model"); + } + so.AddConfigEntry(kOrtSessionOptionsModelExternalInitializersFileFolderPath, "./testdata/"); + Ort::Session session(*ort_env, buffer.data(), buffer.size(), so); + } else { + Ort::Session session(*ort_env, ToPathString(model_with_ext).c_str(), so); + } EXPECT_TRUE(std::filesystem::exists(ep_context_model_file.c_str())); if (expect_external_ini_file) { @@ -803,18 +818,25 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { CleanUpCtxFile(ep_context_model_file); } -// Set the external initializer size threshold to 1024 so FusedMatMul (which fallback on CPU) +// Set the session option "ep.context_model_external_initializers_file_name" so FusedMatMul (which fallback on CPU) // will dump initializer data to external file TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithExternalWeights) { EpCtxCpuNodeWithExternalIniFileTestBody(true); } -// Use the default external initializer size threshold (1024000) so FusedMatMul (which fallback on CPU) -// will NOT dump initializer data to external file +// Without setting the session option "ep.context_model_external_initializers_file_name" +// so FusedMatMul (which fallback on CPU) will NOT dump initializer data to external file TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) { EpCtxCpuNodeWithExternalIniFileTestBody(false); } +// Load model from memory +// Without setting the session option "ep.context_model_external_initializers_file_name" +// so FusedMatMul (which fallback on CPU) will NOT dump initializer data to external file +TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeightsModelFromMemory) { + 
EpCtxCpuNodeWithExternalIniFileTestBody(false, true); +} + // Set ep.context_file_path to folder path which is not a valid option, check the error message TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) { ProviderOptions provider_options; diff --git a/onnxruntime/test/providers/qnn/resize_test.cc b/onnxruntime/test/providers/qnn/resize_test.cc index 702d4e6eddb1b..415e36b9cb93b 100644 --- a/onnxruntime/test/providers/qnn/resize_test.cc +++ b/onnxruntime/test/providers/qnn/resize_test.cc @@ -336,9 +336,7 @@ TEST_F(QnnHTPBackendTests, Resize_DownSample_Linear_HalfPixel) { RunQDQResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 2}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All, - 19, - // Need tolerance of 0.539% of output range after QNN SDK 2.17 - QDQTolerance(0.00539f)); + 19); } // Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "pytorch_half_pixel" @@ -348,9 +346,7 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearPytorchHalfPixel) { RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "linear", "pytorch_half_pixel", "", ExpectedEPNodeAssignment::All, - 19, - // Need tolerance of 0.609% of output range after QNN SDK 2.17 - QDQTolerance(0.00609f)); + 19); } // Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "half_pixel" @@ -360,9 +356,7 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearHalfPixel) { RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All, - 19, - // Need tolerance of 0.609% of output range after QNN SDK 2.17 - QDQTolerance(0.00609f)); + 19); } // Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "align_corners" @@ -372,9 +366,7 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAlignCorners) { RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "linear", "align_corners", "", ExpectedEPNodeAssignment::All, - 19, - // Need tolerance of 0.533% of output range after QNN SDK 2.17 - QDQTolerance(0.00533f)); + 19); } // Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "asymmetric" @@ -384,9 +376,7 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAsymmetric) { RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "linear", "asymmetric", "", ExpectedEPNodeAssignment::All, - 19, - // Need tolerance of 0.619% of output range after QNN SDK 2.17 - QDQTolerance(0.00619f)); + 19); } // Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_floor" diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 761bf63976bec..85f8250f70fc5 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1008,6 +1008,60 @@ TEST_F(QnnHTPBackendTests, Reciprocal_QU8) { ExpectedEPNodeAssignment::All); } +// Test Mean Op on HTP +TEST_F(QnnHTPBackendTests, Mean_TwoInputs) { + std::vector input1 = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector input2 = {5.0f, 6.0f, 7.0f, 8.0f}; + + RunOpTest("Mean", + { + TestInputDef({4}, false, std::move(input1)), + TestInputDef({4}, false, std::move(input2)), + }, + {}, + 13, // Opset version + ExpectedEPNodeAssignment::All); +} + +// Test Mean Op with multiple inputs on HTP +TEST_F(QnnHTPBackendTests, Mean_FourInputs) { + std::vector input1 = {1.0f, 1.0f, 1.0f, 1.0f}; + std::vector input2 = {2.0f, 2.0f, 2.0f, 
2.0f}; + std::vector input3 = {3.0f, 3.0f, 3.0f, 3.0f}; + std::vector input4 = {4.0f, 4.0f, 4.0f, 4.0f}; + + RunOpTest("Mean", + { + TestInputDef({4}, false, std::move(input1)), + TestInputDef({4}, false, std::move(input2)), + TestInputDef({4}, false, std::move(input3)), + TestInputDef({4}, false, std::move(input4)), + }, + {}, + 13, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, Mean_TwoInputs_QU8) { + RunQDQOpTest("Mean", + {TestInputDef({1, 2, 2}, false, GetFloatDataInRange(0.0f, 10.0f, 4)), + TestInputDef({1, 2, 2}, false, GetFloatDataInRange(10.0f, 20.0f, 4))}, + {}, // No attributes for Mean + 13, // Opset version + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, Mean_FourInputs_QU8) { + RunQDQOpTest("Mean", + {TestInputDef({1, 2, 2}, false, GetFloatDataInRange(0.0f, 10.0f, 4)), + TestInputDef({1, 2, 2}, false, GetFloatDataInRange(10.0f, 20.0f, 4)), + TestInputDef({1, 2, 2}, false, GetFloatDataInRange(20.0f, 30.0f, 4)), + TestInputDef({1, 2, 2}, false, GetFloatDataInRange(30.0f, 40.0f, 4))}, + {}, // No attributes for Mean + 13, // Opset version + ExpectedEPNodeAssignment::All); +} + // Test ScatterND op on HTP TEST_F(QnnHTPBackendTests, ScatterND_int64_int64) { std::vector data = {0, 1, 2, 3}; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 05a6a433a152d..56cc234a63832 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3384,11 +3384,11 @@ TEST(CApiTest, TestSharedAllocators) { // NOTE: On x86 builds arenas are not supported and will default to using non-arena based allocator ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg) == nullptr); - // Test that duplicates are handled + // Registration is always a replace operation std::unique_ptr status_releaser( api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg), api.ReleaseStatus); - ASSERT_FALSE(status_releaser.get() == nullptr); + ASSERT_TRUE(status_releaser.get() == nullptr); { // create session 1 @@ -3427,12 +3427,12 @@ TEST(CApiTest, TestSharedAllocators) { MockedOrtAllocator custom_allocator; ASSERT_TRUE(api.RegisterAllocator(env_ptr, &custom_allocator) == nullptr); - // Test that duplicates are handled + // Registration is always a replace operation std::unique_ptr status_releaser( api.RegisterAllocator(env_ptr, &custom_allocator), api.ReleaseStatus); - ASSERT_FALSE(status_releaser.get() == nullptr); + ASSERT_TRUE(status_releaser.get() == nullptr); { // Keep this scoped to destroy the underlying sessions after use @@ -3499,11 +3499,11 @@ TEST(CApiTest, TestSharedAllocators) { std::vector keys, values; ASSERT_TRUE(api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0) == nullptr); - // Test that duplicates are handled + // Registration is always a replace operation std::unique_ptr status_releaser( api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0), api.ReleaseStatus); - ASSERT_FALSE(status_releaser.get() == nullptr); + ASSERT_TRUE(status_releaser.get() == nullptr); { // create session 1 diff --git a/onnxruntime/test/shared_lib/test_run_options.cc b/onnxruntime/test/shared_lib/test_run_options.cc index 1187f2e0d7e7e..ade6d5e43b1b5 100644 --- a/onnxruntime/test/shared_lib/test_run_options.cc +++ b/onnxruntime/test/shared_lib/test_run_options.cc @@ -12,3 +12,10 @@ TEST(CApiTest, 
run_options) { ASSERT_STREQ(options.GetRunTag(), "abc"); ASSERT_EQ(options.GetRunLogVerbosityLevel(), 1); } + +TEST(CApiTest, run_options_config) { + Ort::RunOptions options; + options.AddConfigEntry("foo", "bar"); + EXPECT_STREQ(options.GetConfigEntry("foo"), "bar"); + EXPECT_EQ(options.GetConfigEntry("not foo"), nullptr); +} diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 2d5ffc3c81b0f..8ab58adbeeb74 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -14,7 +14,6 @@ #include "core/framework/ortmemoryinfo.h" #include "cpu/cpu_ops.h" #include "cuda/cuda_ops.h" -#include "rocm/rocm_ops.h" #include "onnxruntime_lite_custom_op.h" static const char* c_OpDomain = "test.customop"; @@ -39,9 +38,6 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA Cuda::RegisterOps(domain); Cuda::RegisterOps(domain_v2); - Rocm::RegisterOps(domain); - Rocm::RegisterOps(domain_v2); - Ort::UnownedSessionOptions session_options(options); session_options.Add(domain); session_options.Add(domain_v2); diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc deleted file mode 100644 index 807182ee28946..0000000000000 --- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_ROCM - -#define ORT_API_MANUAL_INIT -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - -#include "core/providers/rocm/rocm_context.h" -#include "onnxruntime_lite_custom_op.h" - -void rocm_add(int64_t, float*, const float*, const float*, hipStream_t compute_stream); - -using namespace Ort::Custom; - -#define CUSTOM_ENFORCE(cond, msg) \ - if (!(cond)) { \ - throw std::runtime_error(msg); \ - } - -namespace Rocm { - -void KernelOne(const Ort::Custom::RocmContext& rocm_ctx, - const Ort::Custom::Tensor& X, - const Ort::Custom::Tensor& Y, - Ort::Custom::Tensor& Z) { - auto input_shape = X.Shape(); - CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream"); - CUSTOM_ENFORCE(rocm_ctx.miopen_handle, "failed to fetch miopen handle"); - CUSTOM_ENFORCE(rocm_ctx.blas_handle, "failed to fetch rocblas handle"); - auto z_raw = Z.Allocate(input_shape); - rocm_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), rocm_ctx.hip_stream); -} - -void RegisterOps(Ort::CustomOpDomain& domain) { - static const std::unique_ptr c_CustomOpOne{Ort::Custom::CreateLiteCustomOp("CustomOpOne", "ROCMExecutionProvider", KernelOne)}; - domain.Add(c_CustomOpOne.get()); -} - -} // namespace Rocm - -#endif diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h deleted file mode 100644 index d3e9e4040a5c3..0000000000000 --- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
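Stepping back to the new `run_options_config` test above: it exercises the run-level config-entry API on `Ort::RunOptions`. A minimal usage sketch follows; the key/value pair is a placeholder mirroring the test's "foo"/"bar", and real keys are defined in onnxruntime_run_options_config_keys.h.

```cpp
#include <onnxruntime_cxx_api.h>

// Minimal sketch of the API exercised by the new run_options_config test.
// "custom.key"/"custom.value" are placeholders, not ORT-recognized entries.
Ort::RunOptions MakeRunOptions() {
  Ort::RunOptions run_options;
  run_options.AddConfigEntry("custom.key", "custom.value");

  // Entries can be read back; unknown keys return nullptr.
  const char* value = run_options.GetConfigEntry("custom.key");    // "custom.value"
  const char* missing = run_options.GetConfigEntry("missing.key"); // nullptr
  (void)value;
  (void)missing;
  return run_options;
}
```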
- -#pragma once - -namespace Rocm { - -#ifdef USE_ROCM - -void RegisterOps(Ort::CustomOpDomain& domain); - -#else - -inline void RegisterOps(Ort::CustomOpDomain&) {} - -#endif - -} // namespace Rocm diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_from_attention.onnx b/onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_from_attention.onnx new file mode 100644 index 0000000000000..6a88d63da37df Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_from_attention.onnx differ diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 9e563b3342dae..81cb56d34c925 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -231,20 +231,6 @@ std::unique_ptr DefaultArmNNExecutionProvider(bool enable_ar #endif } -std::unique_ptr DefaultRocmExecutionProvider(bool test_tunable_op) { -#ifdef USE_ROCM - OrtROCMProviderOptions provider_options{}; - provider_options.do_copy_in_default_stream = true; - provider_options.tunable_op_enable = test_tunable_op ? 1 : 0; - provider_options.tunable_op_tuning_enable = test_tunable_op ? 1 : 0; - provider_options.tunable_op_max_tuning_duration_ms = 0; - if (auto factory = RocmProviderFactoryCreator::Create(&provider_options)) - return factory->CreateProvider(); -#endif - ORT_UNUSED_PARAMETER(test_tunable_op); - return nullptr; -} - std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlprogram) { // To manually test CoreML model generation on a non-macOS platform, comment out the `&& defined(__APPLE__)` below. // The test will create a model but execution of it will obviously fail. @@ -346,5 +332,8 @@ std::unique_ptr DefaultDmlExecutionProvider() { return nullptr; } +std::unique_ptr DefaultRocmExecutionProvider(bool) { + return nullptr; +} } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu index affb0afa91997..49f8cf078c55a 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu @@ -24,103 +24,102 @@ #define ILP 4 typedef enum { - ADAM_MODE_0 = 0, // L2 regularization mode - ADAM_MODE_1 = 1, // Decoupled weight decay mode (AdamW) as implemented in transformers/AdamW - ADAM_MODE_2 = 2 // Decoupled weight decay mode (AdamW) as implemented in pytorch/AdamW + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1, // Decoupled weight decay mode (AdamW) as implemented in transformers/AdamW + ADAM_MODE_2 = 2 // Decoupled weight decay mode (AdamW) as implemented in pytorch/AdamW } adamMode_t; using MATH_T = float; template struct AdamFunctor { - __device__ __forceinline__ void operator()(int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<4>& tl, - const float beta1, - const float beta2, - const float epsilon, - const float lr, - const float lr_corrected, - const float bias_correction1, - const float bias_correction2, - adamMode_t mode, - const float decay) - { - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* g = (T*)tl.addresses[0][tensor_loc]; - g += chunk_idx * chunk_size; - - T* p = 
(T*)tl.addresses[1][tensor_loc]; - p += chunk_idx * chunk_size; - - T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx * chunk_size; - - T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx * chunk_size; - - n -= chunk_idx * chunk_size; - - // see note in multi_tensor_scale_kernel.cu - for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; + __device__ __forceinline__ void operator()(int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<4>& tl, + const float beta1, + const float beta2, + const float epsilon, + const float lr, + const float lr_corrected, + const float bias_correction1, + const float bias_correction2, + adamMode_t mode, + const float decay) { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T* m = (T*)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T* v = (T*)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) { - r_g[ii] = g[i]; - r_p[ii] = p[i]; - r_m[ii] = m[i]; - r_v[ii] = v[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_m[ii] = MATH_T(0); - r_v[ii] = MATH_T(0); - } - } + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - if (mode == ADAM_MODE_0) { // L2 - r_g[ii] = r_g[ii] + (decay * r_p[ii]); - r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; - MATH_T denom = sqrtf(r_v[ii]) + epsilon; - r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom); - } else if (mode == ADAM_MODE_1) { // weight decay - // Adapted to be mathematically equivalent to transformers AdamW - r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; - MATH_T denom = sqrtf(r_v[ii]) + epsilon; - r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom); - r_p[ii] = r_p[ii] - (lr * decay * r_p[ii]); - } else if (mode == ADAM_MODE_2) { - // Adapted to be mathematically equivalent to torch AdamW - r_p[ii] = r_p[ii] - (r_p[ii] * lr * decay); - r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; - MATH_T denom = (sqrtf(r_v[ii]) / sqrtf(bias_correction2)) + epsilon; - r_p[ii] = r_p[ii] - (lr * r_m[ii]) / (bias_correction1 * denom); - } - } + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T denom = sqrtf(r_v[ii]) + epsilon; + r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / 
denom); + } else if (mode == ADAM_MODE_1) { // weight decay + // Adapted to be mathematically equivalent to transformers AdamW + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T denom = sqrtf(r_v[ii]) + epsilon; + r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom); + r_p[ii] = r_p[ii] - (lr * decay * r_p[ii]); + } else if (mode == ADAM_MODE_2) { + // Adapted to be mathematically equivalent to torch AdamW + r_p[ii] = r_p[ii] - (r_p[ii] * lr * decay); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T denom = (sqrtf(r_v[ii]) / sqrtf(bias_correction2)) + epsilon; + r_p[ii] = r_p[ii] - (lr * r_m[ii]) / (bias_correction1 * denom); + } + } #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) { - p[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; - } - } + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; } + } } + } }; void multi_tensor_adam_cuda(int chunk_size, @@ -133,37 +132,36 @@ void multi_tensor_adam_cuda(int chunk_size, const int step, const int mode, const int bias_correction, - const float weight_decay) -{ - using namespace at; - - // Handle bias correction mode - float bias_correction1 = 1.0, bias_correction2 = 1.0; - float lr_corrected = lr; - if (bias_correction == 1) { - bias_correction1 = 1 - std::pow(beta1, step); - bias_correction2 = 1 - std::pow(beta2, step); - lr_corrected *= std::sqrt(bias_correction2) / bias_correction1; - } - - // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), - 0, - "adam", - multi_tensor_apply<4>(BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor(), - beta1, - beta2, - epsilon, - lr, - lr_corrected, - bias_correction1, - bias_correction2, - (adamMode_t)mode, - weight_decay);) - - AT_CUDA_CHECK(cudaGetLastError()); + const float weight_decay) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0, bias_correction2 = 1.0; + float lr_corrected = lr; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + lr_corrected *= std::sqrt(bias_correction2) / bias_correction1; + } + + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + epsilon, + lr, + lr_corrected, + bias_correction1, + bias_correction2, + (adamMode_t)mode, + weight_decay);) + + AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_apply.cuh b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_apply.cuh index fabbda2e58151..ff868bce9d446 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_apply.cuh +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_apply.cuh @@ -23,11 +23,11 @@ constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; template struct TensorListMetadata { - void* addresses[n][depth_to_max_tensors[n - 1]]; - int 
sizes[depth_to_max_tensors[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. - int start_tensor_this_launch; + void* addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; }; template @@ -35,10 +35,9 @@ __global__ void multi_tensor_apply_kernel(int chunk_size, volatile int* noop_flag, T tl, U callable, - ArgTypes... args) -{ - // Hand the chunk information to the user-supplied functor to process however it likes. - callable(chunk_size, noop_flag, tl, args...); + ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however it likes. + callable(chunk_size, noop_flag, tl, args...); } template @@ -47,82 +46,81 @@ void multi_tensor_apply(int block_size, const at::Tensor& noop_flag, const std::vector>& tensor_lists, T callable, - ArgTypes... args) -{ - TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); - int len0 = tensor_lists[0].size(); - TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); - auto ref_device = tensor_lists[0][0].device(); - TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); - for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices - { - TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); - for (int t = 0; t < tensor_lists[l].size(); t++) { - // TODO: Print which tensor fails. - bool contiguous_memory = tensor_lists[l][t].is_contiguous(); + ArgTypes... args) { + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. 
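(Editor's note, illustrative only and not part of the patch: the TensorListMetadata bookkeeping being reformatted here is easier to follow if the block-to-work mapping is written out on the host side. The sketch below uses simplified names and plain std::vector instead of the fixed-size arrays in the header; it only shows how each CUDA block would learn which tensor and which chunk it owns.)

    // Sketch: how block -> (tensor, chunk) bookkeeping is filled before a
    // multi_tensor_apply-style launch. One CUDA block per planned entry.
    #include <cstdio>
    #include <vector>

    struct Meta {
        std::vector<int> block_to_tensor;  // which tensor a block works on
        std::vector<int> block_to_chunk;   // which chunk of that tensor
    };

    Meta plan_launch(const std::vector<int>& numels, int chunk_size) {
        Meta m;
        for (int t = 0; t < static_cast<int>(numels.size()); ++t) {
            int chunks = (numels[t] + chunk_size - 1) / chunk_size;  // ceil division
            for (int c = 0; c < chunks; ++c) {
                m.block_to_tensor.push_back(t);
                m.block_to_chunk.push_back(c);
            }
        }
        return m;  // blockIdx.x indexes these arrays inside the kernel
    }

    int main() {
        Meta m = plan_launch({1000, 300}, 512);
        std::printf("%zu blocks\n", m.block_to_tensor.size());  // 2 + 1 = 3 blocks
    }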
+ bool contiguous_memory = tensor_lists[l][t].is_contiguous(); #ifdef VERSION_GE_1_5 - contiguous_memory = (contiguous_memory || - tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); + contiguous_memory = (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); #endif - TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); - TORCH_CHECK(tensor_lists[l][t].device() == ref_device, - "A tensor was not on the same device as the first tensor"); - TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); - } + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); } - - int ntensors = tensor_lists[0].size(); - - TensorListMetadata tl; - - const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); - auto stream = at::cuda::getCurrentCUDAStream(); - - tl.start_tensor_this_launch = 0; - int loc_block_info = 0; - int loc_tensor_info = 0; - for (int t = 0; t < ntensors; t++) { - tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); - for (int d = 0; d < depth; d++) - tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); - loc_tensor_info++; - - int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - - for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { - // std::cout << chunks_this_tensor << std::endl; - tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; - tl.block_to_chunk[loc_block_info] = chunk; - loc_block_info++; - - bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && - chunk == chunks_this_tensor - 1); - bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); - bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); - if (tensors_full || blocks_full || last_chunk) { - // using accscalar_t = acc_type; - multi_tensor_apply_kernel<<>>( - chunk_size, noop_flag.data_ptr(), tl, callable, args...); - - AT_CUDA_CHECK(cudaGetLastError()); - - // Reset. The control flow possibilities here make my brain hurt. 
- loc_block_info = 0; - if (chunk == chunks_this_tensor - 1) { - // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << - // std::endl; - loc_tensor_info = 0; - tl.start_tensor_this_launch = t + 1; - } else { - // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << - // std::endl; - tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; - for (int d = 0; d < depth; d++) - tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; - loc_tensor_info = 1; - tl.start_tensor_this_launch = t; - } - } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, noop_flag.data_ptr(), tl, callable, args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. 
+ loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; } + } } + } } diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_axpby_kernel.cu b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_axpby_kernel.cu index 69566cb78a9e3..dd13657d59c6d 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_axpby_kernel.cu +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_axpby_kernel.cu @@ -17,28 +17,26 @@ #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; +template +__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } -template -struct AxpbyFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<3>& tl, - float a, - float b, - int arg_to_check) - { +template +struct AxpbyFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<3>& tl, + float a, + float b, + int arg_to_check) { // I'd like this kernel to propagate infs/nans. 
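(Editor's note: the is_aligned/load_store helpers reformatted above implement a standard vectorized-copy trick, where ILP contiguous elements are moved as a single wide load/store once the pointers are suitably aligned. The host-side sketch below is illustrative only, with kILP fixed at 4; the kernel code above remains the authoritative version.)

    // Sketch: move kILP elements as one aligned wide access.
    #include <cstdint>
    #include <type_traits>

    constexpr int kILP = 4;

    template <typename T>
    bool is_aligned_example(const T* p) {
        return reinterpret_cast<std::uintptr_t>(p) % (kILP * sizeof(T)) == 0;
    }

    template <typename T>
    void load_store_example(T* dst, const T* src, int dst_offset, int src_offset) {
        // One copy of this wide type typically compiles to a single vector move.
        using WideT = typename std::aligned_storage<kILP * sizeof(T), kILP * alignof(T)>::type;
        reinterpret_cast<WideT*>(dst)[dst_offset] =
            reinterpret_cast<const WideT*>(src)[src_offset];
    }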
// if(*noop_gmem == 1) // return; @@ -48,15 +46,15 @@ struct AxpbyFunctor int n = tl.sizes[tensor_loc]; x_t* x = (x_t*)tl.addresses[0][tensor_loc]; - x += chunk_idx*chunk_size; + x += chunk_idx * chunk_size; y_t* y = (y_t*)tl.addresses[1][tensor_loc]; - y += chunk_idx*chunk_size; + y += chunk_idx * chunk_size; out_t* out = (out_t*)tl.addresses[2][tensor_loc]; - out += chunk_idx*chunk_size; + out += chunk_idx * chunk_size; - n -= chunk_idx*chunk_size; + n -= chunk_idx * chunk_size; bool finite = true; x_t r_x[ILP]; @@ -64,96 +62,85 @@ struct AxpbyFunctor out_t r_out[ILP]; // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { // load - load_store(r_x, x, 0 , i_start); - load_store(r_y, y, 0 , i_start); + load_store(r_x, x, 0, i_start); + load_store(r_y, y, 0, i_start); #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_out[ii] = a*static_cast(r_x[ii]) + b*static_cast(r_y[ii]); - if(arg_to_check == -1) + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = a * static_cast(r_x[ii]) + b * static_cast(r_y[ii]); + if (arg_to_check == -1) finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii])); - if(arg_to_check == 0) + if (arg_to_check == 0) finite = finite && isfinite(r_x[ii]); - if(arg_to_check == 1) + if (arg_to_check == 1) finite = finite && isfinite(r_y[ii]); } // store - load_store(out, r_out, i_start , 0); + load_store(out, r_out, i_start, 0); } - } - else - { + } else { // Non-divergent exit condition for __syncthreads, not necessary here - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { r_x[ii] = 0; r_y[ii] = 0; - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { r_x[ii] = x[i]; r_y[ii] = y[i]; } } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_out[ii] = a*static_cast(r_x[ii]) + b*static_cast(r_y[ii]); - if(arg_to_check == -1) + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = a * static_cast(r_x[ii]) + b * static_cast(r_y[ii]); + if (arg_to_check == -1) finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii])); - if(arg_to_check == 0) + if (arg_to_check == 0) finite = finite && isfinite(r_x[ii]); - if(arg_to_check == 1) + if (arg_to_check == 1) finite = finite && isfinite(r_y[ii]); } // see note in multi_tensor_scale_kernel.cu #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) out[i] = r_out[ii]; } } } - if(!finite) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. + if (!finite) + *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. 
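(Editor's note: the "blindly fire off a write" line above is how these fused kernels report overflow for mixed-precision loss scaling: any thread that sees a non-finite input writes 1 to a shared flag, and racing writes of the same value are benign. The CPU-side analogue below is illustrative only and not part of the patch.)

    // Sketch: axpby with a non-finite "noop" flag, as used for loss scaling.
    #include <cmath>
    #include <cstdio>

    void axpby_with_flag(const float* x, const float* y, float* out,
                         int n, float a, float b, int* noop_flag) {
        for (int i = 0; i < n; ++i) {
            out[i] = a * x[i] + b * y[i];
            if (!std::isfinite(x[i]) || !std::isfinite(y[i])) {
                *noop_flag = 1;  // on the GPU every thread writes the same value
            }
        }
    }

    int main() {
        float x[3] = {1.f, 2.f, INFINITY}, y[3] = {1.f, 1.f, 1.f}, out[3];
        int flag = 0;
        axpby_with_flag(x, y, out, 3, 0.5f, 0.5f, &flag);
        std::printf("overflow flag = %d\n", flag);  // prints 1
    }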
} }; void multi_tensor_axpby_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector>& tensor_lists, - float a, - float b, - int arg_to_check) -{ + int chunk_size, + at::Tensor noop_flag, + std::vector>& tensor_lists, + float a, + float b, + int arg_to_check) { using namespace at; // The output (downscaled) type is always float. // If build times suffer, think about where to put this dispatch, // and what logic should be moved out of multi_tensor_apply. DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda", - DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda", - DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda", - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AxpbyFunctor(), - a, - b, - arg_to_check); ))) + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda", + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda", + multi_tensor_apply<3>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AxpbyFunctor(), + a, + b, + arg_to_check);))) AT_CUDA_CHECK(cudaGetLastError()); diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_scale_kernel.cu b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_scale_kernel.cu index e6cba04432ed0..d15f72ce39d27 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_scale_kernel.cu +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_scale_kernel.cu @@ -19,26 +19,24 @@ #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; +template +__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } -template -struct ScaleFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<2>& tl, - float scale) - { +template +struct ScaleFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<2>& tl, + float scale) { // I'd like this kernel to propagate infs/nans. 
// if(*noop_gmem == 1) // return; @@ -48,45 +46,38 @@ struct ScaleFunctor int n = tl.sizes[tensor_loc]; in_t* in = (in_t*)tl.addresses[0][tensor_loc]; - in += chunk_idx*chunk_size; + in += chunk_idx * chunk_size; out_t* out = (out_t*)tl.addresses[1][tensor_loc]; - out += chunk_idx*chunk_size; + out += chunk_idx * chunk_size; - n -= chunk_idx*chunk_size; + n -= chunk_idx * chunk_size; bool finite = true; in_t r_in[ILP]; out_t r_out[ILP]; // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { // load - load_store(r_in, in, 0 , i_start); + load_store(r_in, in, 0, i_start); #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { r_out[ii] = static_cast(r_in[ii]) * scale; finite = finite && isfinite(r_in[ii]); } // store load_store(out, r_out, i_start, 0); } - } - else - { + } else { // Non-divergent exit condition for __syncthreads, not necessary here - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { r_in[ii] = 0; - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) r_in[ii] = in[i]; } // note for clarification to future michael: @@ -95,45 +86,42 @@ struct ScaleFunctor // Put another way, the STGs are dependent on the LDGs, but not on each other. // There is still compute ILP benefit from unrolling the loop though. #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { + for (int ii = 0; ii < ILP; ii++) { r_out[ii] = static_cast(r_in[ii]) * scale; finite = finite && isfinite(r_in[ii]); } #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) out[i] = r_out[ii]; } } } - if(!finite) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. + if (!finite) + *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. } }; void multi_tensor_scale_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector>& tensor_lists, - float scale) -{ + int chunk_size, + at::Tensor noop_flag, + std::vector>& tensor_lists, + float scale) { using namespace at; // The output (downscaled) type is always float. // If build times suffer, think about where to put this dispatch, // and what logic should be moved out of multi_tensor_apply. 
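(Editor's note: the nested DISPATCH_DOUBLE_FLOAT_AND_HALF calls that follow expand to a switch over the runtime scalar type of each tensor list, so one templated functor covers every (in_t, out_t) combination. The stripped-down sketch below shows only that dispatch shape; dispatch_float_and_half, launch_scale, and half_placeholder are hypothetical stand-ins, not the real macro or types.)

    // Sketch: nested runtime-to-compile-time type dispatch.
    #include <stdexcept>

    enum class ScalarType { Float, Half };
    struct half_placeholder { unsigned short bits; };  // stand-in for a real half type

    template <typename Fn>
    void dispatch_float_and_half(ScalarType t, Fn&& fn) {
        switch (t) {
            case ScalarType::Float: fn(float{}); break;
            case ScalarType::Half:  fn(half_placeholder{}); break;
            default: throw std::runtime_error("unsupported type");
        }
    }

    template <typename in_t, typename out_t>
    void launch_scale(float /*scale*/) { /* a real build would launch the functor here */ }

    void scale_dispatch(ScalarType in_type, ScalarType out_type, float scale) {
        dispatch_float_and_half(in_type, [&](auto in_tag) {
            dispatch_float_and_half(out_type, [&](auto out_tag) {
                launch_scale<decltype(in_tag), decltype(out_tag)>(scale);
            });
        });
    }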
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", - DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - ScaleFunctor(), - scale); )) + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", + multi_tensor_apply<2>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + ScaleFunctor(), + scale);)) AT_CUDA_CHECK(cudaGetLastError()); // AT_CUDA_CHECK(cudaDeviceSynchronize()); diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu index dd6a44b9e3b56..71a49ad69d589 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu @@ -84,9 +84,9 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad { template void Impl_##name(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count); #define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(x) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \ - SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16) #define ACTIVATION_GRAD_OP_NAME(name) \ diff --git a/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu b/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu index 4ea692c88ca1b..861a178ff7d44 100644 --- a/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu @@ -34,13 +34,12 @@ __global__ void _GistPack1EncoderKernel( uint8_t* output_data, const size_t factor, const CUDA_LONG N) { - - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); // id of Y (compressed tensor) + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); // id of Y (compressed tensor) uint8_t out = 0x0; uint8_t bit_out = 0x0; size_t begin = id * factor; size_t end = id * factor + factor; - for(size_t idx = begin; idx < end; idx++){ + for (size_t idx = begin; idx < end; idx++) { bool bit = (input_data[idx] > (T)0); int nidxshift = idx % factor; bit_out = bit ? 
(0x80 >> nidxshift) : 0; @@ -54,7 +53,7 @@ __global__ void _GistPack1DecoderKernel( T* output_data, const size_t factor, const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); // id of Y (uncompressed tensor) + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); // id of Y (uncompressed tensor) int nidx = id / factor; int nidxshift = id % factor; uint8_t mask = 0x80 >> nidxshift; @@ -67,7 +66,6 @@ __global__ void _GistPack8EncoderKernel( const T* input_data, uint8_t* output_data, const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); T X = input_data[id]; @@ -85,7 +83,7 @@ __global__ void _GistPack8EncoderKernel( uint32_t pack_e_size = 5; uint32_t pack_m_size = 2; uint8_t bias = 127; - switch(sizeof(T)){ + switch (sizeof(T)) { case 4: m_size = 23; e_size = 8; @@ -115,16 +113,15 @@ __global__ void _GistPack8EncoderKernel( uint32_t pack_e = e >> pack_e_shift; uint32_t pack_m = m >> pack_m_shift; uint32_t m_residual = m & m_residual_mask; - if(m_residual > 0){ // round up - if(pack_m == 0x3){ - pack_e +=1; // increase exponent + if (m_residual > 0) { // round up + if (pack_m == 0x3) { + pack_e += 1; // increase exponent pack_m = 0; - } - else{ - pack_m +=1; // increase mantissa + } else { + pack_m += 1; // increase mantissa } } - if (pack_e >= 0x1f) { //NaN values + if (pack_e >= 0x1f) { // NaN values pack_e = 0; } output_data[id] = (s << (pack_e_size + pack_m_size)) | (pack_e << pack_m_size) | pack_m; @@ -149,7 +146,7 @@ __global__ void _GistPack8DecoderKernel( uint32_t e_size = 8; uint32_t bias = 127; - switch(sizeof(T)){ + switch (sizeof(T)) { case 4: m_size = 23; e_size = 8; @@ -162,14 +159,14 @@ __global__ void _GistPack8DecoderKernel( break; } uint32_t pack_e_shift = e_size - pack_e_size; - uint32_t s = i >> (pack_e_size+ pack_m_size); + uint32_t s = i >> (pack_e_size + pack_m_size); uint32_t pack_e = i & pack_e_mask; pack_e >>= pack_m_size; uint32_t pack_m = i & pack_m_mask; uint32_t unpack_e = pack_e << (pack_e_shift + m_size); unpack_e += bias; - uint32_t unpack_m = pack_m << (m_size -pack_m_size); - uint32_t unpack = (s << (m_size+e_size)) | unpack_e | unpack_m; + uint32_t unpack_m = pack_m << (m_size - pack_m_size); + uint32_t unpack = (s << (m_size + e_size)) | unpack_e | unpack_m; output_data[id] = (T)__uint_as_float((unsigned int)unpack); } @@ -240,7 +237,6 @@ __global__ void _GistPackMsfp15EncoderKernel( } } - // If inf/nan is found, zero out values if (shared_exp >= 0xff) { for (size_t i = 0; i < tile_size; i++) { @@ -253,7 +249,6 @@ __global__ void _GistPackMsfp15EncoderKernel( return; } - // Copy of shared exponent for packing uint32_t pack_shared_exp = shared_exp; @@ -297,19 +292,19 @@ __global__ void _GistPackMsfp15EncoderKernel( // Store {exponent bit, mantissa} in output uint8_t exp_bit = (pack_shared_exp % 2) << pack_e_shift; pack_shared_exp = pack_shared_exp >> 1; - output_data[in_i] = (uint8_t) (exp_bit | (sign >> (s_shift - pack_s_shift)) | mantissa); + output_data[in_i] = (uint8_t)(exp_bit | (sign >> (s_shift - pack_s_shift)) | mantissa); } } template __global__ void _GistPackMsfp15DecoderKernel( - const uint8_t* input_data, - T* output_data, - const CUDA_LONG num_threads, - const CUDA_LONG pre_axis_size, - const CUDA_LONG axis_size, - const CUDA_LONG num_tiles, - const CUDA_LONG tile_size) { + const uint8_t* input_data, + T* output_data, + const CUDA_LONG num_threads, + const CUDA_LONG pre_axis_size, + const CUDA_LONG axis_size, + const CUDA_LONG num_tiles, + const CUDA_LONG tile_size) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_threads); // 
Quantization parameters @@ -346,7 +341,7 @@ __global__ void _GistPackMsfp15DecoderKernel( // Get sign bit uint32_t sign = X & pack_s_mask; // Get mantissa - uint32_t mantissa = (uint32_t) (X & pack_m_mask); + uint32_t mantissa = (uint32_t)(X & pack_m_mask); if (mantissa == 0) { output_data[in_i] = 0.0; @@ -364,7 +359,7 @@ __global__ void _GistPackMsfp15DecoderKernel( mantissa = mantissa & ((1 << 23) - 1); // Reconstruct float number - uint32_t output = (sign << (s_shift - pack_s_shift)) | (exp << e_shift) | mantissa; + uint32_t output = (sign << (s_shift - pack_s_shift)) | (exp << e_shift) | mantissa; output_data[in_i] = (float)__uint_as_float(output); } } @@ -464,7 +459,6 @@ void GistPackMsfp15EncoderImpl( const size_t pre_axis_size, const size_t axis_size, const size_t tile_size) { - assert(axis_size % tile_size == 0); const int num_tiles = static_cast(axis_size / tile_size); @@ -472,25 +466,23 @@ void GistPackMsfp15EncoderImpl( int blocksPerGrid = (int)(ceil(static_cast(threads) / GridDim::maxThreadsPerBlock)); _GistPackMsfp15EncoderKernel<<>>( - input_data, - output_data, - (CUDA_LONG)threads, - (CUDA_LONG)pre_axis_size, - (CUDA_LONG)axis_size, - (CUDA_LONG)num_tiles, - (CUDA_LONG)tile_size - ); + input_data, + output_data, + (CUDA_LONG)threads, + (CUDA_LONG)pre_axis_size, + (CUDA_LONG)axis_size, + (CUDA_LONG)num_tiles, + (CUDA_LONG)tile_size); } template void GistPackMsfp15DecoderImpl( - cudaStream_t stream, - const uint8_t* input_data, - T* output_data, - const size_t pre_axis_size, - const size_t axis_size, - const size_t tile_size) { - + cudaStream_t stream, + const uint8_t* input_data, + T* output_data, + const size_t pre_axis_size, + const size_t axis_size, + const size_t tile_size) { assert(axis_size % tile_size == 0); const int num_tiles = static_cast(axis_size / tile_size); @@ -498,14 +490,13 @@ void GistPackMsfp15DecoderImpl( int blocksPerGrid = (int)(ceil(static_cast(threads) / GridDim::maxThreadsPerBlock)); _GistPackMsfp15DecoderKernel<<>>( - input_data, - output_data, - (CUDA_LONG)threads, - (CUDA_LONG)pre_axis_size, - (CUDA_LONG)axis_size, - (CUDA_LONG)num_tiles, - (CUDA_LONG)tile_size - ); + input_data, + output_data, + (CUDA_LONG)threads, + (CUDA_LONG)pre_axis_size, + (CUDA_LONG)axis_size, + (CUDA_LONG)num_tiles, + (CUDA_LONG)tile_size); } #define SPECIALIZED_IMPL_BIN_ENC(T) \ diff --git a/orttraining/orttraining/training_ops/cuda/math/isfinite_impl.cu b/orttraining/orttraining/training_ops/cuda/math/isfinite_impl.cu index f4a1a95e2af7c..93c5bb62a7474 100644 --- a/orttraining/orttraining/training_ops/cuda/math/isfinite_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/isfinite_impl.cu @@ -6,7 +6,6 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/math/isfinite.cuh" - namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale_impl.cu b/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale_impl.cu index 0a32141eeaefe..63597cdffbbf1 100644 --- a/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale_impl.cu @@ -28,7 +28,7 @@ void Impl_MixedPrecisionScale( const SrcT* input_data, const float* scale_data, DstT* output_data, - size_t count){ + size_t count) { int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock)); CUDA_LONG N = static_cast(count); _MixedPrecisionScale<<>>( @@ -39,12 +39,12 @@ void Impl_MixedPrecisionScale( } #define 
SPECIALIZE_MIXEDPRECISIONSCALE_IMPL(SrcT, DstT) \ -template void Impl_MixedPrecisionScale( \ - cudaStream_t stream, \ - const SrcT* input_data, \ - const float* scale_data, \ - DstT* output_data, \ - size_t count); + template void Impl_MixedPrecisionScale( \ + cudaStream_t stream, \ + const SrcT* input_data, \ + const float* scale_data, \ + DstT* output_data, \ + size_t count); SPECIALIZE_MIXEDPRECISIONSCALE_IMPL(half, half) SPECIALIZE_MIXEDPRECISIONSCALE_IMPL(half, float) diff --git a/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu b/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu index e1e4fcc968470..376f9fc6cfcd5 100644 --- a/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/scale_impl.cu @@ -19,7 +19,7 @@ __global__ void _Scale( #pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { if (id < N) { - input_value[i] = input_data[id]; + input_value[i] = input_data[id]; id += NumThreadsPerBlock; } } @@ -50,13 +50,13 @@ void Impl_Scale( N); } -#define SPECIALIZE_SCALE_IMPL(T) \ -template void Impl_Scale( \ - cudaStream_t stream, \ - const T* input_data, \ - const float scale_value, \ - T* output_data, \ - size_t count); +#define SPECIALIZE_SCALE_IMPL(T) \ + template void Impl_Scale( \ + cudaStream_t stream, \ + const T* input_data, \ + const float scale_value, \ + T* output_data, \ + size_t count); SPECIALIZE_SCALE_IMPL(half) SPECIALIZE_SCALE_IMPL(float) diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu index 48ec60017a3cd..3b5bd895c1f54 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu @@ -1,18 +1,18 @@ /** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /* Modifications Copyright (c) Microsoft. */ @@ -28,105 +28,104 @@ namespace onnxruntime { namespace cuda { - template - __global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad, const input_t* output, - int element_count, int batch_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method - // warp_softmax_backward_kernel. 
- constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - #ifdef USE_ROCM - constexpr int WARP_BATCH = 1; - #else - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - #endif - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_count might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_count - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x % WARP_SIZE; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, - // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep - // the nested loops. - // This should have no impact on performance because the loops are unrolled anyway. - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; - acc_t grad_output_reg[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE]; - output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; - grad_output_reg[i][it] = grad_reg[i][it] * output_reg[i][it]; - } else { - grad_reg[i][it] = acc_t(0); - output_reg[i][it] = acc_t(0); - grad_output_reg[i][it] = acc_t(0); - } +template +__global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad, const input_t* output, + int element_count, int batch_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method + // warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; +#ifdef USE_ROCM + constexpr int WARP_BATCH = 1; +#else + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; +#endif + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_count might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_count - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x % WARP_SIZE; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. 
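(Editor's note: the warp_reduce calls in this kernel combine the per-lane partial sums across a warp. The project's own warp_reduce may differ in signature and batching; the snippet below is only an illustrative sketch of the usual butterfly shuffle reduction, assuming a full 32-lane warp with all lanes active.)

    // Sketch: butterfly warp-wide sum via shuffle intrinsics.
    __device__ __forceinline__ float warp_sum_example(float v) {
        for (int offset = 16; offset > 0; offset >>= 1) {
            v += __shfl_xor_sync(0xffffffff, v, offset);
        }
        return v;  // every lane now holds the warp-wide sum
    }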
+ // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; + acc_t grad_output_reg[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE]; + output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; + grad_output_reg[i][it] = grad_reg[i][it] * output_reg[i][it]; + } else { + grad_reg[i][it] = acc_t(0); + output_reg[i][it] = acc_t(0); + grad_output_reg[i][it] = acc_t(0); } } + } - acc_t sum[WARP_BATCH]; - if (!is_log_softmax) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_output_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_output_reg[i][it]; - } + acc_t sum[WARP_BATCH]; + if (!is_log_softmax) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_output_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_output_reg[i][it]; } - warp_reduce(sum); } - else { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } + warp_reduce(sum); + } else { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; } - warp_reduce(sum); } + warp_reduce(sum); + } - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - if (is_log_softmax) { - gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); - } else { - gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - sum[i] ) * output_reg[i][it]; - } +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + if (is_log_softmax) { + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); + } else { + gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - sum[i]) * output_reg[i][it]; } } } } +} // The function "softmax_warp_backward" saves intermediate results in float32 using registers to prevent recomputing, which can be beneficial for small shapes. // However, for larger shapes, the usage of a large register resource can lead to low CUDA warp occupancy and poor performance. @@ -134,7 +133,7 @@ namespace cuda { // TODO: If the dimension to do softmax is greater than 2048, saving the input into shared memory can further reduce register usage. 
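(Editor's note: both backward kernels in this file compute the same per-row formulas; with y the (log-)softmax output and g the incoming gradient, the stored gradient is as in the host-side reference below. Illustrative only, not part of the patch.)

    // Reference: per-row softmax / log-softmax backward.
    //   Softmax:      dx_i = (g_i - sum_j g_j * y_j) * y_i
    //   Log-softmax:  dx_i =  g_i - exp(y_i) * sum_j g_j
    #include <cmath>
    #include <vector>

    std::vector<float> softmax_backward_row(const std::vector<float>& g,
                                            const std::vector<float>& y,
                                            bool is_log_softmax) {
        float sum = 0.f;
        for (size_t j = 0; j < g.size(); ++j) {
            sum += is_log_softmax ? g[j] : g[j] * y[j];
        }
        std::vector<float> dx(g.size());
        for (size_t i = 0; i < g.size(); ++i) {
            dx[i] = is_log_softmax ? g[i] - std::exp(y[i]) * sum
                                   : (g[i] - sum) * y[i];
        }
        return dx;
    }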
template __global__ void softmax_warp_backward_register_efficicent(output_t* gradInput, const input_t* grad, const input_t* output, - int element_count, int batch_count) { + int element_count, int batch_count) { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method // warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; @@ -183,21 +182,20 @@ __global__ void softmax_warp_backward_register_efficicent(output_t* gradInput, c acc_t sum[WARP_BATCH]; if (!is_log_softmax) { - #pragma unroll +#pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { sum[i] = (acc_t)(grad_reg[i][0]) * (acc_t)(output_reg[i][0]); - #pragma unroll +#pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { sum[i] += (acc_t)(grad_reg[i][it]) * (acc_t)(output_reg[i][it]); } } warp_reduce(sum); - } - else { - #pragma unroll + } else { +#pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { sum[i] = (acc_t)grad_reg[i][0]; - #pragma unroll +#pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { sum[i] += (acc_t)grad_reg[i][it]; } @@ -218,7 +216,7 @@ __global__ void softmax_warp_backward_register_efficicent(output_t* gradInput, c if (is_log_softmax) { gradInput[i * element_count + it * WARP_SIZE] = ((acc_t)(grad_reg[i][it]) - std::exp((acc_t)output_reg[i][it]) * sum[i]); } else { - gradInput[i * element_count + it * WARP_SIZE] = ((acc_t)grad_reg[i][it] - sum[i] ) * (acc_t)output_reg[i][it]; + gradInput[i * element_count + it * WARP_SIZE] = ((acc_t)grad_reg[i][it] - sum[i]) * (acc_t)output_reg[i][it]; } } } @@ -257,22 +255,22 @@ Status SoftmaxGradImpl(cudaStream_t stream, cudnnHandle_t cudnn_handle, T* input // Launch code would be more elegant if C++ supported FOR CONSTEXPR constexpr int start_to_use_register_efficient_func = 11; switch (log2_elements) { -#define LAUNCH_KERNEL(log2_elements_value, kernel_name) \ - if (is_log_softmax) { \ - kernel_name \ - <<>>(input_grad, output_grad, softmax_output, element_count, batch_count); \ - } else { \ - kernel_name \ - <<>>(input_grad, output_grad, softmax_output, element_count, batch_count); \ +#define LAUNCH_KERNEL(log2_elements_value, kernel_name) \ + if (is_log_softmax) { \ + kernel_name \ + <<>>(input_grad, output_grad, softmax_output, element_count, batch_count); \ + } else { \ + kernel_name \ + <<>>(input_grad, output_grad, softmax_output, element_count, batch_count); \ } -#define CASE_LOG2_ELEMENTS(log2_elements_value) \ - case log2_elements_value: { \ - if constexpr (log2_elements_value < start_to_use_register_efficient_func) { \ - LAUNCH_KERNEL(log2_elements_value, softmax_warp_backward); \ - } else { \ - LAUNCH_KERNEL(log2_elements_value, softmax_warp_backward_register_efficicent); \ - } \ +#define CASE_LOG2_ELEMENTS(log2_elements_value) \ + case log2_elements_value: { \ + if constexpr (log2_elements_value < start_to_use_register_efficient_func) { \ + LAUNCH_KERNEL(log2_elements_value, softmax_warp_backward); \ + } else { \ + LAUNCH_KERNEL(log2_elements_value, softmax_warp_backward_register_efficicent); \ + } \ } break CASE_LOG2_ELEMENTS(0); // 1 @@ -316,5 +314,5 @@ SPECIALIZED_SOFTMAX_GRAD_IMPL(double) #endif #undef SPECIALIZED_SOFTMAX_GRAD_IMPL -} -} +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index ad577afa06c18..29af6b0c251d9 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ 
b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -1,18 +1,18 @@ /** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ // // Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. @@ -31,39 +31,39 @@ namespace onnxruntime { namespace cuda { namespace { - // This is the un-specialized struct. Note that we prevent instantiation of this - // struct by putting an undefined symbol in the function body so it won't compile. - // template - // struct SharedMemory - // { - // // Ensure that we won't compile any un-specialized types - // __device__ T *getPointer() - // { - // extern __device__ void error(void); - // error(); - // return NULL; - // } - // }; - // https://github.com/NVIDIA/apex/issues/246 - template - struct SharedMemory; +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. 
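(Editor's note: the SharedMemory specializations reformatted in this hunk work around the fact that templated kernels cannot declare extern __shared__ arrays of arbitrary element types directly, per the apex issue cited in the comment. The kernel below is a hypothetical usage sketch, illustrative only; example_kernel is not part of the patch and relies on the SharedMemory struct defined in this file.)

    // Sketch: using the per-type SharedMemory accessor inside a templated kernel.
    template <typename U>
    __global__ void example_kernel(const U* in, U* out, int n) {
        SharedMemory<U> shared;        // specialized for float / double in this file
        U* buf = shared.getPointer();  // dynamic shared memory, sized at launch time
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        buf[threadIdx.x] = (i < n) ? in[i] : U(0);
        __syncthreads();
        if (i < n) out[i] = buf[threadIdx.x];
    }
    // Launch with the dynamic shared-memory size as the third <<<>>> argument, e.g.
    //   example_kernel<float><<<grid, block, block.x * sizeof(float), stream>>>(in, out, n);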
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; - template <> - struct SharedMemory { - __device__ float* getPointer() { - extern __shared__ float s_float[]; - return s_float; - } - }; +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; - template <> - struct SharedMemory { - __device__ double* getPointer() { - extern __shared__ double s_double[]; - return s_double; - } - }; - } // namespace +template <> +struct SharedMemory { + __device__ double* getPointer() { + extern __shared__ double s_double[]; + return s_double; + } +}; +} // namespace template __device__ void cuLoadWriteStridedInputs( @@ -189,12 +189,12 @@ __global__ void cuComputePartGradGammaBeta( // compute partial sums from strided inputs // do this to increase number of loads in flight cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, output, dout, - i1_end, n2, gamma, beta, mean, invvar); + row_stride, warp_buf1, warp_buf2, input, output, dout, + i1_end, n2, gamma, beta, mean, invvar); for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) { cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, output, dout, - i1_end, n2, gamma, beta, mean, invvar); + row_stride, warp_buf1, warp_buf2, input, output, dout, + i1_end, n2, gamma, beta, mean, invvar); } __syncthreads(); // inter-warp reductions @@ -303,7 +303,7 @@ __global__ void cuComputeGradInput( const U c_mean = (use_mean && !simplified) ? mean[i1] : U(0); const U c_invvar = invvar[i1]; const T* k_input = use_mean ? input + i1 * n2 : nullptr; - const V* k_output = use_mean ? nullptr: output + i1 * n2; + const V* k_output = use_mean ? 
nullptr : output + i1 * n2; const V* k_dout = dout + i1 * n2; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; @@ -473,23 +473,23 @@ __global__ void cuComputeGradInput( template void HostLayerNormGradient( - const cudaDeviceProp& prop, - cudaStream_t stream, - const V* dout, - const T* input, - const V* output, - const V* gamma, - const V* beta, - const U* mean, - const U* invvar, - int64_t n1, - int64_t n2, - T* grad_input, - V* grad_gamma, - V* grad_beta, - U* part_grad_gamma, - U* part_grad_beta, - const int part_size) { + const cudaDeviceProp& prop, + cudaStream_t stream, + const V* dout, + const T* input, + const V* output, + const V* gamma, + const V* beta, + const U* mean, + const U* invvar, + int64_t n1, + int64_t n2, + T* grad_input, + V* grad_gamma, + V* grad_beta, + U* part_grad_gamma, + U* part_grad_beta, + const int part_size) { const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST); @@ -501,30 +501,30 @@ void HostLayerNormGradient( if (mean == nullptr && !simplified) { // use_mean == false, simplified == false -> Inverted Layer Norm cuComputePartGradGammaBeta<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - part_grad_gamma, - part_grad_beta); + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + part_grad_gamma, + part_grad_beta); } else { // use_mean == true, simplified == false -> Layer Norm // use_mean == true, simplified == true -> Simplified Layer Norm cuComputePartGradGammaBeta<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - part_grad_gamma, - part_grad_beta); + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + part_grad_gamma, + part_grad_beta); } const dim3 threads3(warp_size, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); @@ -549,50 +549,50 @@ void HostLayerNormGradient( if (mean == nullptr && !simplified) { if (gamma == nullptr) { cuComputeGradInput<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - grad_input); + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); } else { cuComputeGradInput<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - grad_input); + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); } } else { if (gamma == nullptr) { cuComputeGradInput<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - grad_input); + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); } else { cuComputeGradInput<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - grad_input); + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); } } } diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/sg_impl.cu b/orttraining/orttraining/training_ops/cuda/optimizer/sg_impl.cu index 1f23309e883e5..0a17449ef29f6 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/sg_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/optimizer/sg_impl.cu @@ -51,7 +51,7 @@ void SGDOptimizerImpl( #define SPECIALIZED_IMPL__SGDOptimizerImpl(T) \ template void SGDOptimizerImpl( \ - cudaStream_t stream, \ + cudaStream_t stream, \ const T* eta, \ const T* weights, \ const T* gradients, \ diff --git a/orttraining/orttraining/training_ops/cuda/reduction/all_impl.cu b/orttraining/orttraining/training_ops/cuda/reduction/all_impl.cu index 
0da76b0a4bec5..638c7d663737e 100644 --- a/orttraining/orttraining/training_ops/cuda/reduction/all_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/reduction/all_impl.cu @@ -21,13 +21,11 @@ __global__ void assign_false(bool* ptr) { *ptr = false; } -template<> +template <> void LaunchAllKernel(cudaStream_t stream, const bool* data, const int size, bool* output) { - if(thrust::all_of(thrust::cuda::par.on(stream), data, data + size, thrust::identity())) { + if (thrust::all_of(thrust::cuda::par.on(stream), data, data + size, thrust::identity())) { assign_true<<<1, 1, 0, stream>>>(output); - } - else - { + } else { assign_false<<<1, 1, 0, stream>>>(output); } } diff --git a/orttraining/orttraining/training_ops/cuda/reduction/reduction_all_impl.cu b/orttraining/orttraining/training_ops/cuda/reduction/reduction_all_impl.cu index 3d17755c86a7f..7c19535388ae2 100644 --- a/orttraining/orttraining/training_ops/cuda/reduction/reduction_all_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/reduction/reduction_all_impl.cu @@ -30,7 +30,7 @@ template void ScalarSqrt(cudaStream_t stream, float* input, BFloat16* output); template __launch_bounds__(ChunkGroup<1>::thread_count_per_block) -__global__ void MultiTensorReduceKernel(ChunkGroup<1> chunk_group, TOut* output) { + __global__ void MultiTensorReduceKernel(ChunkGroup<1> chunk_group, TOut* output) { const int group_index = chunk_group.block_index_to_tensor_group_index[blockIdx.x]; const int tensor_size = chunk_group.tensor_sizes[group_index]; const int chunk_size = chunk_group.chunk_size; diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu index 2091a7082ee79..4e0e7fb7b8e21 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu @@ -61,7 +61,7 @@ void FlattenAndUnpadImpl(cudaStream_t stream, output_data); } -#define FLATTEN_AND_UNPAD_IMPL(T) \ +#define FLATTEN_AND_UNPAD_IMPL(T) \ template void FlattenAndUnpadImpl(cudaStream_t stream, \ const int64_t total_element_count, \ const fast_divmod output_element_stride_fdm, \ diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu index bb7b5d2860bca..acbbe0d94d999 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu @@ -70,14 +70,14 @@ void GetSortedIndices( nullptr, temp_storage_size_bytes, dX_indices, dX_indices_sorted.get(), dY_indices.get(), dY_indices_sorted.get(), - num_gathered_indices, 0, sizeof(TIndex)*8, stream)); + num_gathered_indices, 0, sizeof(TIndex) * 8, stream)); auto temp_storage = allocator.GetScratchBuffer(temp_storage_size_bytes); CUDA_CALL_THROW(cub::DeviceRadixSort::SortPairs( temp_storage.get(), temp_storage_size_bytes, dX_indices, dX_indices_sorted.get(), dY_indices.get(), dY_indices_sorted.get(), - num_gathered_indices, 0, sizeof(TIndex)*8, stream)); + num_gathered_indices, 0, sizeof(TIndex) * 8, stream)); dX_indices_sorted_out = std::move(dX_indices_sorted); dY_indices_sorted_out = std::move(dY_indices_sorted); diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.cu index e96770f974bf0..52d26a922a719 100644 --- 
a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.cu @@ -61,7 +61,7 @@ void PadAndUnflattenImpl(cudaStream_t stream, output_data); } -#define PAD_AND_UNFLATTEN_IMPL(T) \ +#define PAD_AND_UNFLATTEN_IMPL(T) \ template void PadAndUnflattenImpl(cudaStream_t stream, \ const int64_t total_element_count, \ const fast_divmod output_element_stride_fdm, \ diff --git a/orttraining/orttraining/training_ops/rocm/activation/gelu_grad_impl_common.cuh b/orttraining/orttraining/training_ops/rocm/activation/gelu_grad_impl_common.cuh index 4d6f3291b139c..2377aae9abb54 100644 --- a/orttraining/orttraining/training_ops/rocm/activation/gelu_grad_impl_common.cuh +++ b/orttraining/orttraining/training_ops/rocm/activation/gelu_grad_impl_common.cuh @@ -24,7 +24,7 @@ __device__ __inline__ T ComputeGeluGradScalar(T dY, T X, gelu_computation_mode:: const float sqrt_param = 0.79788456080286535587989211986876f; const float mul_param = 0.044715f; - + constexpr float one = 1.0; constexpr float two = 2.0; @@ -33,7 +33,7 @@ __device__ __inline__ T ComputeGeluGradScalar(T dY, T X, gelu_computation_mode:: // float tan_h = tanhf(sqrt_param * (X_float + X_float * x2mul)); float u = two * sqrt_param * (X_float + X_float * x2mul); float emu = __expf(-u); - float tan_h = two/(one + emu) - one; + float tan_h = two / (one + emu) - one; float dg1 = 0.5f * (1.0f + tan_h); float dg2 = X_float * 0.5f * sqrt_param * (1 - tan_h * tan_h); diff --git a/setup.py b/setup.py index 37f615e9a5f8a..c13f8160cb112 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,6 @@ def parse_arg_remove_string(argv, arg_name_equal): is_cuda_version_12 = False rocm_version = None is_migraphx = False -is_rocm = False is_openvino = False is_qnn = False # The following arguments are mutually exclusive @@ -66,11 +65,6 @@ def parse_arg_remove_string(argv, arg_name_equal): cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=") if cuda_version: is_cuda_version_12 = cuda_version.startswith("12.") -elif parse_arg_remove_boolean(sys.argv, "--use_rocm"): - is_rocm = True - rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=") - if parse_arg_remove_boolean(sys.argv, "--use_migraphx"): - is_migraphx = True elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"): is_migraphx = True elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): @@ -94,9 +88,6 @@ def parse_arg_remove_string(argv, arg_name_equal): elif parse_arg_remove_boolean(sys.argv, "--use_qnn"): is_qnn = True package_name = "onnxruntime-qnn" - -if is_rocm: - package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly" elif is_migraphx: package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly" @@ -314,7 +305,7 @@ def finalize_options(self): return ret -providers_cuda_or_rocm = "onnxruntime_providers_" + ("rocm" if is_rocm else "cuda") +providers_cuda_or_rocm = "onnxruntime_providers_cuda" providers_tensorrt_or_migraphx = "onnxruntime_providers_" + ("migraphx" if is_migraphx else "tensorrt") providers_nv_tensorrt_rtx = "onnxruntime_providers_nv_tensorrt_rtx" providers_openvino = "onnxruntime_providers_openvino" @@ -541,7 +532,6 @@ def finalize_options(self): local_version = None enable_training = parse_arg_remove_boolean(sys.argv, "--enable_training") enable_training_apis = parse_arg_remove_boolean(sys.argv, "--enable_training_apis") -enable_rocm_profiling = parse_arg_remove_boolean(sys.argv, "--enable_rocm_profiling") 
disable_auditwheel_repair = parse_arg_remove_boolean(sys.argv, "--disable_auditwheel_repair") default_training_package_device = parse_arg_remove_boolean(sys.argv, "--default_training_package_device") @@ -722,8 +712,6 @@ def reformat_run_count(count_str): if local_version: version_number = version_number + local_version - if is_rocm and enable_rocm_profiling: - version_number = version_number + ".profiling" if wheel_name_suffix: if not (enable_training and wheel_name_suffix == "gpu"): diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 4c1ddd94fda5a..f0782dff23345 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -486,7 +486,6 @@ def generate_build_tree( "-Donnxruntime_ENABLE_CPU_FP16_OPS=" + ("ON" if args.enable_training else "OFF"), "-Donnxruntime_USE_NCCL=" + ("ON" if args.enable_nccl else "OFF"), "-Donnxruntime_BUILD_BENCHMARKS=" + ("ON" if args.build_micro_benchmarks else "OFF"), - "-Donnxruntime_USE_ROCM=" + ("ON" if args.use_rocm else "OFF"), "-Donnxruntime_GCOV_COVERAGE=" + ("ON" if args.code_coverage else "OFF"), "-Donnxruntime_ENABLE_MEMORY_PROFILE=" + ("ON" if args.enable_memory_profile else "OFF"), "-Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=" + ("ON" if args.enable_cuda_line_info else "OFF"), @@ -500,12 +499,10 @@ def generate_build_tree( + ("ON" if args.enable_wasm_exception_throwing_override else "OFF"), "-Donnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER=" + ("ON" if args.wasm_run_tests_in_browser else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=" + ("ON" if args.enable_wasm_threads else "OFF"), - "-Donnxruntime_ENABLE_WEBASSEMBLY_MEMORY64=" + ("ON" if args.enable_wasm_memory64 else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"), "-Donnxruntime_ENABLE_LAZY_TENSOR=" + ("ON" if args.enable_lazy_tensor else "OFF"), "-Donnxruntime_ENABLE_CUDA_PROFILING=" + ("ON" if args.enable_cuda_profiling else "OFF"), - "-Donnxruntime_ENABLE_ROCM_PROFILING=" + ("ON" if args.enable_rocm_profiling else "OFF"), "-Donnxruntime_USE_XNNPACK=" + ("ON" if args.use_xnnpack else "OFF"), "-Donnxruntime_USE_WEBNN=" + ("ON" if args.use_webnn else "OFF"), "-Donnxruntime_USE_CANN=" + ("ON" if args.use_cann else "OFF"), @@ -559,7 +556,7 @@ def generate_build_tree( vcpkg_installation_root = os.path.join(os.path.abspath(build_dir), "vcpkg") if not os.path.exists(vcpkg_installation_root): run_subprocess( - ["git", "clone", "-b", "2025.03.19", "https://github.com/microsoft/vcpkg.git", "--recursive"], + ["git", "clone", "-b", "2025.06.13", "https://github.com/microsoft/vcpkg.git", "--recursive"], cwd=build_dir, ) vcpkg_toolchain_path = Path(vcpkg_installation_root) / "scripts" / "buildsystems" / "vcpkg.cmake" @@ -647,12 +644,7 @@ def generate_build_tree( # Choose the cmake triplet triplet = None if args.build_wasm: - # The support for wasm64 is still in development. - if args.enable_wasm_memory64: - # The triplet wasm64-emscripten doesn't exist in vcpkg's official repo. 
- triplet = "wasm64-emscripten" - else: - triplet = "wasm32-emscripten" + triplet = "wasm32-emscripten" elif args.android: if args.android_abi == "armeabi-v7a": triplet = "arm-neon-android" @@ -708,8 +700,6 @@ def generate_build_tree( cmake_args.append("-DCMAKE_C_COMPILER_LAUNCHER=ccache") if args.use_cuda: cmake_args.append("-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache") - if args.use_rocm: - cmake_args.append("-DCMAKE_HIP_COMPILER_LAUNCHER=ccache") if args.external_graph_transformer_path: cmake_args.append("-Donnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH=" + args.external_graph_transformer_path) @@ -728,9 +718,9 @@ def generate_build_tree( cmake_args += ["-Donnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD=ON"] if args.use_migraphx: cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) - if args.use_rocm: cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) + if args.use_tensorrt or args.use_nv_tensorrt_rtx: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) @@ -1899,7 +1889,6 @@ def build_python_wheel( default_training_package_device=False, use_ninja=False, enable_training_apis=False, - enable_rocm_profiling=False, ): for config in configs: cwd = get_config_build_dir(build_dir, config) @@ -1919,8 +1908,6 @@ def build_python_wheel( args.append("--enable_training") if enable_training_apis: args.append("--enable_training_apis") - if enable_rocm_profiling: - args.append("--enable_rocm_profiling") # The following arguments are mutually exclusive if use_cuda: @@ -2579,7 +2566,6 @@ def main(): default_training_package_device=default_training_package_device, use_ninja=(args.cmake_generator == "Ninja"), enable_training_apis=args.enable_training_apis, - enable_rocm_profiling=args.enable_rocm_profiling, ) if args.build_nuget: diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 561eab7f2d61d..22f9cc054006e 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -346,7 +346,6 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD.") parser.add_argument("--enable_wasm_relaxed_simd", action="store_true", help="Enable WebAssembly Relaxed SIMD.") parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threading.") - parser.add_argument("--enable_wasm_memory64", action="store_true", help="Enable WebAssembly 64-bit memory.") parser.add_argument("--disable_wasm_exception_catching", action="store_true", help="Disable exception catching.") parser.add_argument( "--enable_wasm_api_exception_catching", @@ -604,18 +603,6 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: help="Enable CUDA kernel profiling (requires CUPTI in PATH).", ) - # --- ROCm --- - rocm_group = parser.add_argument_group("ROCm Execution Provider") - rocm_group.add_argument("--use_rocm", action="store_true", help="Enable ROCm EP.") - rocm_group.add_argument("--rocm_version", help="ROCm stack version.") - rocm_group.add_argument("--rocm_home", help="Path to ROCm installation directory.") - # ROCm-specific profiling - rocm_group.add_argument( - "--enable_rocm_profiling", - action="store_true", - help="Enable ROCm kernel profiling.", - ) - # --- DNNL (formerly MKL-DNN / oneDNN) --- dnnl_group = parser.add_argument_group("DNNL Execution Provider") dnnl_group.add_argument("--use_dnnl", action="store_true", help="Enable DNNL EP.") @@ -731,6 
+718,9 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: migx_group = parser.add_argument_group("MIGraphX Execution Provider") migx_group.add_argument("--use_migraphx", action="store_true", help="Enable MIGraphX EP.") migx_group.add_argument("--migraphx_home", help="Path to MIGraphX installation directory.") + migx_group.add_argument("--use_rocm", action="store_true", help="Enable ROCm EP.") + migx_group.add_argument("--rocm_version", help="ROCm stack version.") + migx_group.add_argument("--rocm_home", help="Path to ROCm installation directory.") # --- WebNN --- webnn_group = parser.add_argument_group("WebNN Execution Provider") diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index b925619401c27..d3bd9c79afe08 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -56,7 +56,7 @@ extends: msbuildPlatform: x64 packageName: x64-cuda CudaVersion: ${{ parameters.CudaVersion }} - buildparameter: --use_cuda --cuda_home=${{ variables.win_cuda_home }} --enable_onnx_tests --enable_wcos --use_webgpu --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" + buildparameter: --use_cuda --cuda_home=${{ variables.win_cuda_home }} --enable_onnx_tests --use_webgpu --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" runTests: false buildJava: false java_artifact_id: onnxruntime_gpu diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml deleted file mode 100644 index c6ebb80f98e12..0000000000000 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ /dev/null @@ -1,187 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -name: 'linux_ci_$(Date:yyyyMMdd)_$(Rev:r)' - -# gid of video and render group on gcramdrr1-mi100-085 and -86 -variables: - - name: video - value: 44 - - name: render - value: 109 - - name: RocmVersion - value: 6.4 - -jobs: -- job: Linux_Build - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - timeoutInMinutes: 240 - - steps: - - - checkout: self - clean: true - submodules: recursive - - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" - Repository: onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) - - - task: Cache@2 - inputs: - key: '"$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - cacheHitVar: CACHE_RESTORED - restoreKeys: | - "$(TODAY)" | 
"$(Build.SourceBranch)" - "$(TODAY)" | - displayName: Cache Task - - - script: mkdir -p $(CCACHE_DIR) - condition: ne(variables.CACHE_RESTORED, 'true') - displayName: Create Cache Dir - - - task: CmdLine@2 - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --user $UID:$(id -g $USER) \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(CCACHE_DIR):/cache \ - -e CCACHE_DIR=/cache \ - --workdir /onnxruntime_src \ - onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) \ - /bin/bash -c " - set -ex; \ - env; \ - ccache -s; \ - python tools/ci_build/build.py \ - --config Release \ - --enable_training \ - --cmake_extra_defines \ - CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ \ - onnxruntime_BUILD_KERNEL_EXPLORER=OFF \ - onnxruntime_USE_COMPOSABLE_KERNEL=OFF \ - --use_migraphx \ - --rocm_version=$(RocmVersion) \ - --rocm_home /opt/rocm \ - --nccl_home /opt/rocm \ - --enable_nccl \ - --update \ - --build_dir /build \ - --build \ - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache \ - --build_wheel \ - --skip_submodule_sync \ - --use_cache \ - --skip_tests --cmake_path /usr/bin/cmake --ctest_path /usr/bin/ctest; \ - ccache -sv; \ - ccache -z" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Build onnxruntime' - - - task: CmdLine@2 - inputs: - script: | - cd $(Build.BinariesDirectory)/Release - find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt - displayName: 'Find Executable Files' - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test - workspace: - clean: all - pool: AMD-GPU - dependsOn: - - Linux_Build - timeoutInMinutes: 120 - - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - buildType: 'current' - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - checkout: self - clean: true - submodules: recursive - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" - Repository: onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) - - - task: CmdLine@2 - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user $UID:$(id -g $USER) \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --workdir /build/Release \ - onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) \ - /bin/bash -c " - set -ex; \ - cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \ - bash /onnxruntime_src/tools/ci_build/github/pai/pai_test_launcher.sh" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run onnxruntime unit tests' - - - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml deleted file mode 100644 index 7388ed6d5a1e9..0000000000000 --- 
a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ /dev/null @@ -1,237 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -name: 'linux_ci_$(Date:yyyyMMdd)_$(Rev:r)' - -# gid of video and render group on gcramdrr1-mi100-085 and -86 -variables: - - name: video - value: 44 - - name: render - value: 109 - - name: RocmVersion - value: 6.4 - -jobs: -- job: Linux_Build - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - timeoutInMinutes: 240 - - steps: - - - checkout: self - clean: true - submodules: recursive - - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion) --build-arg BUILD_UID=1004" - Repository: onnxruntimerocm-cibuild-rocm$(RocmVersion)new - - - task: Cache@2 - inputs: - key: '"$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - cacheHitVar: CACHE_RESTORED - restoreKeys: | - "$(TODAY)" | "$(Build.SourceBranch)" - "$(TODAY)" | - displayName: Cache Task - - - script: mkdir -p $(CCACHE_DIR) - condition: ne(variables.CACHE_RESTORED, 'true') - displayName: Create Cache Dir - - - task: CmdLine@2 - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --user $UID:$(id -g $USER) \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(CCACHE_DIR):/cache \ - -e CCACHE_DIR=/cache \ - --workdir /onnxruntime_src \ - onnxruntimerocm-cibuild-rocm$(RocmVersion)new \ - /bin/bash -c " - set -ex; \ - env; \ - ccache -s; \ - python tools/ci_build/build.py \ - --config Release \ - --cmake_extra_defines \ - CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ \ - onnxruntime_BUILD_KERNEL_EXPLORER=ON \ - CMAKE_HIP_ARCHITECTURES=gfx90a \ - --use_rocm \ - --rocm_version=$(RocmVersion) \ - --rocm_home /opt/rocm \ - --nccl_home /opt/rocm \ - --enable_nccl \ - --update \ - --build_dir /build \ - --build \ - --build_shared_lib \ - --parallel \ - --build_wheel \ - --enable_onnx_tests \ - --skip_submodule_sync \ - --use_cache \ - --skip_tests --cmake_path /usr/bin/cmake --ctest_path /usr/bin/ctest; \ - ccache -sv; \ - ccache -z" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Build onnxruntime' - - - task: CmdLine@2 - inputs: - script: | - cd $(Build.BinariesDirectory)/Release - find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt - displayName: 'Find Executable Files' - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test - workspace: - clean: all - pool: AMD-GPU - dependsOn: - - 
Linux_Build - timeoutInMinutes: 120 - - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - buildType: 'current' - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - checkout: self - clean: true - submodules: recursive - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion) --build-arg BUILD_UID=1004" - Repository: onnxruntimerocm-cibuild-rocm$(RocmVersion)new - - - task: CmdLine@2 - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user $UID:$(id -g $USER) \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - --workdir /build/Release \ - onnxruntimerocm-cibuild-rocm$(RocmVersion)new \ - /bin/bash -c " - set -ex; \ - xargs -a /build/Release/perms.txt chmod a+x; \ - python /onnxruntime_src/tools/ci_build/build.py \ - --config Release \ - --cmake_extra_defines \ - CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ \ - onnxruntime_BUILD_KERNEL_EXPLORER=ON \ - CMAKE_HIP_ARCHITECTURES=gfx90a \ - --mpi_home /opt/ompi \ - --use_rocm \ - --rocm_version=$(RocmVersion) \ - --rocm_home /opt/rocm \ - --nccl_home /opt/rocm \ - --enable_nccl \ - --build_dir /build \ - --build_shared_lib \ - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache \ - --build_wheel \ - --skip_submodule_sync \ - --test --enable_onnx_tests --enable_transformers_tool_test \ - --cmake_path /usr/bin/cmake --ctest_path /usr/bin/ctest" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run onnxruntime unit tests' - - - task: CmdLine@2 - inputs: - script: |- - docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user $UID:$(id -g $USER) \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - -e OPENBLAS_NUM_THREADS=1 \ - -e OPENMP_NUM_THREADS=1 \ - -e MKL_NUM_THREADS=1 \ - -e KERNEL_EXPLORER_BUILD_DIR=/build/Release \ - -e KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE=8 \ - -e KERNEL_EXPLORER_TEST_USE_CUPY=0 \ - -e CUPY_CACHE_DIR=/build/Release \ - onnxruntimerocm-cibuild-rocm$(RocmVersion)new \ - /bin/bash -c " - set -ex; \ - python --version; id ; ls -lha /home ; \ - ls /opt/miniconda/envs/rocm-ci/lib/; \ - pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 4 --reruns 1 --durations=100" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run kernel explorer tests' - condition: succeededOrFailed() - - - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml deleted file mode 100644 index 25d58ffe0b29e..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml +++ /dev/null @@ -1,145 +0,0 @@ -parameters: -- name: PythonVersion - type: string - -- name: RocmVersion - type: string - -- name: RocmVersionPatchSuffix - type: string - default: '' - -- name: BuildConfig - type: 
string - default: 'Release' - -jobs: -- job: wheels_python_${{ replace(parameters.PythonVersion,'.','_') }}_rocm_${{ replace(parameters.RocmVersion,'.','_') }}_${{ parameters.BuildConfig }} - workspace: - clean: all - timeoutInMinutes: 360 - pool: Ubuntu-2204-rocm-aiinfra - variables: - - name: PythonVersion - value: ${{ parameters.PythonVersion }} - - name: EnableProfiling - ${{ if eq(parameters.BuildConfig, 'Release') }}: - value: '' - ${{ else }}: - value: '--enable_rocm_profiling' - - name: ArtifactName - ${{ if eq(parameters.BuildConfig, 'Release') }}: - value: 'onnxruntime_rocm' - ${{ else }}: - value: 'onnxruntime_rocm_Debug' - - steps: - - task: CmdLine@2 - displayName: 'check variables' - inputs: - script: | - echo "BuildConfig is "${{ parameters.BuildConfig }} && \ - echo "EnableProfiling is "${{ variables['EnableProfiling'] }} && \ - echo "ArtifactName is "${{ variables['ArtifactName'] }} - - - checkout: self - clean: true - submodules: recursive - - - template: set-python-manylinux-variables-step.yml - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm - Context: tools/ci_build/github/linux/docker - CheckOutManyLinux: true - DockerBuildArgs: >- - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur - --build-arg BUILD_UID=$(id -u) - --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root - --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: - --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib - --build-arg ROCM_VERSION=${{ parameters.RocmVersion }}${{ parameters.RocmVersionPatchSuffix }} - Repository: onnxruntimetrainingrocmbuild-rocm${{ parameters.RocmVersion }} - - - task: CmdLine@2 - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --privileged \ - --ipc=host \ - --network=host \ - --cap-add=SYS_PTRACE \ - --security-opt seccomp=unconfined \ - -e CC=/opt/rh/gcc-toolset-12/root/usr/bin/cc -e CXX=/opt/rh/gcc-toolset-12/root/usr/bin/c++ -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --workdir /onnxruntime_src \ - --entrypoint $(PythonManylinuxDir)/bin/python3 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION \ - --user onnxruntimedev \ - onnxruntimetrainingrocmbuild-rocm${{ parameters.RocmVersion }} \ - /onnxruntime_src/tools/ci_build/build.py \ - --config ${{ parameters.BuildConfig }} \ - --use_rocm \ - --use_migraphx \ - --rocm_version=${{ parameters.RocmVersion }} \ - --rocm_home=/opt/rocm \ - --nccl_home=/opt/rocm \ - --update \ - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache \ - --build_dir /build \ - --build \ - --build_wheel \ - --skip_tests \ - --enable_training \ - --cmake_extra_defines \ - CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ \ - onnxruntime_BUILD_UNIT_TESTS=OFF FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER \ - ${{ variables['EnableProfiling'] }} - workingDirectory: $(Build.SourcesDirectory) - displayName: 
'Build onnxruntime (in container)' - - # All UTs were here are now covered in AMD CI - see orttraining-pai-ci-pipeline.yml - # This CI is mainly responsible for packaging. The uploaded whl could be used in the downstream CIs (if any). - # For example, docker image build (e.g., PTCA), reporting CI, etc. to further verify the whl as needed. - # To view the UTs disabled from this CI - see https://github.com/microsoft/onnxruntime/pull/11945 for examples - - - script: |- - # Do not output ##vso[] commands with `set -x` or they may be parsed again and include a trailing quote. - set +x - echo "Tests will run using HIP_VISIBLES_DEVICES=$HIP_VISIBLE_DEVICES" - video_gid=$(getent group | awk '/video/ {split($0,a,":"); print(a[3])}') - echo "Found video_gid=$video_gid; attempting to set as pipeline variable" - echo "##vso[task.setvariable variable=video]$video_gid" - render_gid=$(getent group | awk '/render/ {split($0,a,":"); print(a[3])}') - echo "Found render_gid=$render_gid; attempting to set as pipeline variable" - echo "##vso[task.setvariable variable=render]$render_gid" - condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) - displayName: 'Find video and render gid to be mapped into container' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)' - Contents: "${{ parameters.BuildConfig }}/dist/*.whl" - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Upload Rocm wheel as build artifact' - inputs: - ArtifactName: ${{ variables['ArtifactName'] }} - - - script: | - files=($(Build.ArtifactStagingDirectory)/${{ parameters.BuildConfig }}/dist/*.whl) && \ - echo ${files[0]} && \ - python3 tools/ci_build/upload_python_package_to_azure_storage.py \ - --python_wheel_path ${files[0]} \ - --final_storage - condition: and(ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true'), and(succeeded(), eq(variables['DRY_RUN'], '0'))) - displayName: 'Upload Rocm wheel to release repository' - - - - template: clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu new file mode 100644 index 0000000000000..8749502461ac5 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -0,0 +1,21 @@ +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1 + +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 + +#Install Vulkan and RPM Fusion for NVIDIA drivers (UBI8/RHEL8) +RUN dnf install -y \ + https://download1.rpmfusion.org/free/el/rpmfusion-free-release-8.noarch.rpm \ + https://download1.rpmfusion.org/nonfree/el/rpmfusion-nonfree-release-8.noarch.rpm && \ + dnf install -y xorg-x11-drv-nvidia akmod-nvidia vulkan vulkan-tools mesa-vulkan-drivers + +# TODO: Currently this Dockerfile works only for building WebGPU. Need to make it also support running WebGPU tests. 
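[Editor's note on the new Dockerfile.manylinux2_28_webgpu: a quick way to sanity-check an image built from it is to confirm that the Vulkan packages pulled in by the dnf step above actually landed. The image tag below is a placeholder for whatever `docker build -t ...` produced locally; this is a hedged smoke test for local use, not part of the CI.]

```python
import subprocess

# Placeholder tag; substitute whatever tag you used when building the image from
# tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu.
IMAGE = "ort-webgpu-build:local"

# rpm -q exits non-zero if any queried package is missing from the image.
result = subprocess.run(
    ["docker", "run", "--rm", IMAGE, "rpm", "-q", "vulkan-tools", "mesa-vulkan-drivers"],
    capture_output=True,
    text=True,
)
print(result.stdout.strip())
print("Vulkan packages present:", result.returncode == 0)
```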
+ +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh +RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts + +ARG BUILD_UID=1001 +ARG BUILD_USER=onnxruntimedev +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile deleted file mode 100644 index 83a4e04435b95..0000000000000 --- a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile +++ /dev/null @@ -1,104 +0,0 @@ -# Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete -FROM ubuntu:22.04 - -ARG ROCM_VERSION=6.4 -ARG AMDGPU_VERSION=${ROCM_VERSION} -ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' - -CMD ["/bin/bash"] - -RUN echo "$APT_PREF" > /etc/apt/preferences.d/rocm-pin-600 - -ENV DEBIAN_FRONTEND noninteractive - -RUN apt-get update && \ - apt-get install -y --no-install-recommends ca-certificates curl libnuma-dev gnupg && \ - curl -sL https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - &&\ - printf "deb [arch=amd64] https://repo.radeon.com/rocm/apt/$ROCM_VERSION/ jammy main" | tee /etc/apt/sources.list.d/rocm.list && \ - printf "deb [arch=amd64] https://repo.radeon.com/amdgpu/$AMDGPU_VERSION/ubuntu jammy main" | tee /etc/apt/sources.list.d/amdgpu.list && \ - apt-get update && apt-get install -y --no-install-recommends \ - sudo git \ - libelf1 \ - kmod \ - file zip unzip \ - python3 \ - python3-pip \ - rocm-dev \ - rocm-libs \ - build-essential && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -RUN groupadd -g 109 render - -# Upgrade to meet security requirements -RUN apt-get update -y && apt-get upgrade -y && apt-get autoremove -y && \ - apt-get install -y locales cifs-utils wget half libnuma-dev lsb-release && \ - apt-get clean -y - -RUN locale-gen en_US.UTF-8 -RUN update-locale LANG=en_US.UTF-8 -ENV LC_ALL C.UTF-8 -ENV LANG C.UTF-8 - -WORKDIR /stage - -# Cmake -ENV CMAKE_VERSION=3.31.5 -RUN cd /usr/local && \ - wget -q https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-Linux-x86_64.tar.gz && \ - tar -zxf /usr/local/cmake-3.31.5-Linux-x86_64.tar.gz --strip=1 -C /usr - -# Install Ninja -COPY scripts/install-ninja.sh /build_scripts/ -RUN /bin/bash /build_scripts/install-ninja.sh - -# Install VCPKG -ENV VCPKG_INSTALLATION_ROOT=/usr/local/share/vcpkg -ENV VCPKG_FORCE_SYSTEM_BINARIES=ON -COPY scripts/install-vcpkg.sh /build_scripts/ -RUN /bin/bash /build_scripts/install-vcpkg.sh - -# ccache -RUN mkdir -p /tmp/ccache && \ - cd /tmp/ccache && \ - wget -q -O - https://github.com/ccache/ccache/releases/download/v4.7.4/ccache-4.7.4-linux-x86_64.tar.xz | tar --strip 1 -J -xf - && \ - cp /tmp/ccache/ccache /usr/bin && \ - rm -rf /tmp/ccache - -# Install Conda -ENV PATH /opt/miniconda/bin:${PATH} -RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh --no-check-certificate && /bin/bash ~/miniconda.sh -b -p /opt/miniconda && \ - conda init bash && \ - conda config --set auto_activate_base false && \ - conda update --all && \ - rm ~/miniconda.sh && conda clean -ya - -# Create rocm-ci environment -ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/rocm-ci -ENV CONDA_DEFAULT_ENV rocm-ci -RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.10 -ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} - -# Enable rocm-ci 
environment -SHELL ["conda", "run", "-n", "rocm-ci", "/bin/bash", "-c"] - -# Some DLLs in the conda environment have conflict with the one installed in Ubuntu system. -# For example, the GCC version in the conda environment is 12.x, while the one in the Ubuntu 22.04 is 11.x. -# ln -sf to make sure we always use libstdc++.so.6 and libgcc_s.so.1 in the system. -RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bin/../lib/libstdc++.so.6 -RUN ln -sf /usr/lib/x86_64-linux-gnu/libgcc_s.so.1 ${CONDA_ENVIRONMENT_PATH}/bin/../lib/libgcc_s.so.1 - -RUN pip install packaging \ - ml_dtypes==0.5.0 \ - pytest==7.4.4 \ - pytest-xdist \ - pytest-rerunfailures \ - scipy==1.14.1 \ - numpy==1.26.4 - -ARG BUILD_UID=1000 -ARG BUILD_USER=onnxruntimedev -RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID -USER $BUILD_USER -WORKDIR /home/$BUILD_USER diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index 84a86dea6adcd..223f1f46f2e70 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -4,7 +4,7 @@ packaging protobuf==4.21.12 numpy==1.21.6 ; python_version < '3.9' numpy==2.0.0 ; python_version >= '3.9' -torch +torch>=2.6.0 coloredlogs==15.0 transformers==4.48.0 parameterized>=0.8.1
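[Editor's note on the transformers-test requirements change at the end of this diff: the previously unpinned `torch` entry gains a `>=2.6.0` floor. To verify an existing environment against the new constraint locally, a small check along these lines should work; it relies only on `packaging`, which is already listed in this requirements file.]

```python
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("torch"))
required = Version("2.6.0")
print(f"torch {installed} installed; transformers tests now require >= {required}")
assert installed >= required, (
    "upgrade torch to satisfy tools/ci_build/requirements/transformers-test/requirements.txt"
)
```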