diff --git a/.github/labeler.yml b/.github/labeler.yml index c14e2a213bc60..21ca6769d491c 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -11,7 +11,6 @@ ep:oneDNN: '/\bone\s*dnn\b/i' ep:OpenVINO: '/\bopen\s*vino\b/i' ep:QNN: '/\bqnn\b/i' ep:RockchipNPU: '/\brockchip(?:npu)?\b/i' -ep:ROCm: '/\brocm\b/i' ep:SNPE: '/\bsnpe\b/i' ep:tvm: '/\btvm\b/i' ep:VitisAI: '/\bvitis(?:ai)?\b/i' diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 7f7ff74959d52..732fc69a604f9 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -24,10 +24,14 @@ permissions: jobs: AndroidBinarySizeCheckJob_MinimalBaseline: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=AndroidBinarySizeCheckJob_MinimalBaseline-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -66,7 +70,7 @@ jobs: set_var("BuildConfigOs", config["os"]) shell: python working-directory: ${{ github.workspace }} - + - name: 1a. Build onnxruntime run: | set -e -x @@ -110,9 +114,13 @@ jobs: shell: bash android_nnapi_ep: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=android_nnapi_ep-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 @@ -185,9 +193,14 @@ jobs: android_cpu_ep: name: Android CI Pipeline - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=android_cpu_ep-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] + steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 30f832f67c5ee..d02fbfa018473 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -9,10 +9,14 @@ on: jobs: validate: name: "validate" - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=cffconvert-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - name: Check out a copy of the repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index d33e4d923a0bc..67a0d6c573f4d 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -23,7 +23,11 @@ concurrency: jobs: analyze: name: Analyze - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=codeql-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: actions: read contents: read @@ -38,7 +42,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 # Initializes the CodeQL tools for 
scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 04177b11e9c30..cadddd0d7653d 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -13,9 +13,13 @@ on: jobs: validation: name: "Validation" - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=gradle-wrapper-validation-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: gradle/actions/wrapper-validation@v5 concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index 0d2046b980783..ed572aa339ce9 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -20,7 +20,7 @@ jobs: runs-on: macos-14 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 00960c848b107..3ede5919feedf 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -8,7 +8,11 @@ permissions: jobs: triage: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=labeler-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - uses: github/issue-labeler@v3.4 with: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5aaab5f8e1a10..dc8e90a83a03a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -37,12 +37,16 @@ jobs: lint-python-format: # Required workflow name: Python format - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=lint-python-format-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: contents: read security-events: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: @@ -87,7 +91,7 @@ jobs: name: Optional Lint C++ runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Update PATH run: | echo "$HOME/.local/bin" >> "$GITHUB_PATH" @@ -114,9 +118,13 @@ jobs: lint-js: name: Lint JavaScript - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=lint-js-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-node@v6 with: node-version: 20 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 2370c631b7a7a..f7830d1c97114 100644 --- 
a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -34,7 +34,11 @@ on: jobs: build-wasm: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build-wasm-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] env: buildArch: x64 common_build_args: >- @@ -49,7 +53,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: recursive diff --git a/.github/workflows/linux_ci.yml b/.github/workflows/linux_ci.yml index 6f517f2656e94..9aa8418c55a40 100644 --- a/.github/workflows/linux_ci.yml +++ b/.github/workflows/linux_ci.yml @@ -48,6 +48,7 @@ jobs: dockerfile_path: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile docker_image_repo: onnxruntimecpubuildcix64 extra_build_flags: '--enable_address_sanitizer' + job_identifier: build-linux-x64-debug # python_path_prefix: '' # Default empty string is fine, no prefix needed secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -63,6 +64,7 @@ jobs: docker_image_repo: onnxruntimecpubuildpythonx64 extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --build_nuget --enable_transformers_tool_test --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' # $ needs escaping in single quotes + job_identifier: build-linux-x64-release secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -77,6 +79,7 @@ jobs: docker_image_repo: onnxruntimecpubuildpythonx64 # Shares image with standard x64 release extra_build_flags: '--enable_training --use_binskim_compliant_compile_flags --build_wheel --build_nuget --enable_transformers_tool_test --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' # $ needs escaping in single quotes + job_identifier: orttraining-linux-ci-pipeline secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -92,6 +95,7 @@ jobs: docker_image_repo: onnxruntimecpubuildciaarch64 # ASan disabled due to excessive runtime (>4hr). Includes wheel build for basic checks. 
extra_build_flags: '--use_binskim_compliant_compile_flags --build_shared_lib' + job_identifier: build-linux-arm64-debug secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -106,5 +110,6 @@ jobs: docker_image_repo: onnxruntimecpubuildpythonaarch64 extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' # $ needs escaping in single quotes + job_identifier: build-linux-arm64-release secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 886705471b7de..3075d5cbba09b 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -34,6 +34,7 @@ jobs: run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory execution_providers: 'cuda' + job_identifier: build-linux-cuda-x64-release secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Pass token for reusable workflow needs (e.g., docker build action) @@ -43,12 +44,13 @@ jobs: runs-on: - self-hosted - "1ES.Pool=Onnxruntime-github-Linux-GPU-H100" + - "JobId=test-linux-cuda-x64-release-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" permissions: contents: read packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index af86975ee6cdc..3246b4e07c2f3 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -22,13 +22,17 @@ jobs: # Job 1: Build full onnxruntime and generate ORT format test files build_full_ort: name: 1. Build Full ORT and Generate ORT Files - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_full_ort-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: contents: read packages: write steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -58,14 +62,18 @@ jobs: # Job 2: Build minimal onnxruntime [exceptions DISABLED, type reduction DISABLED, training ops ENABLED] build_minimal_exceptions_disabled: name: 2. Build Minimal (Exceptions Disabled) - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_minimal_exceptions_disabled-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -115,14 +123,18 @@ jobs: build_minimal_custom_ops: name: 3a. 
Build Minimal (Custom Ops) needs: build_full_ort # Depends on Job 1 for test data - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_minimal_custom_ops-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -149,14 +161,18 @@ jobs: build_minimal_type_reduction: name: 3b. Build Minimal (Type Reduction) needs: build_full_ort # Depends on Job 1 for test data - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_minimal_type_reduction-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -181,14 +197,18 @@ jobs: # Job 4: Build minimal onnxruntime [exceptions ENABLED, type reduction ENABLED (globally allowed types)] and run tests build_minimal_globally_allowed_types: name: 4. Build Minimal (Globally Allowed Types) - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_minimal_globally_allowed_types-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -215,14 +235,18 @@ jobs: # Job 5: Build extended minimal onnxruntime and run tests build_extended_minimal: name: 5. Build Extended Minimal - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_extended_minimal-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -279,14 +303,18 @@ jobs: # Job 6a: Regular build with python and all optional features disabled. build_regular_no_optional: name: 6a. 
Build Regular (No Optional Features) - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_regular_no_optional-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -356,14 +384,18 @@ jobs: # Job 6b: Minimal build with all optional features disabled. build_minimal_no_optional: name: 6b. Build Minimal (No Optional Features) - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_minimal_no_optional-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -423,14 +455,18 @@ jobs: # Job 6c: Extended minimal build with all optional features disabled. build_extended_minimal_no_optional: name: 6c. Build Extended Minimal (No Optional Features) - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_extended_minimal_no_optional-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -498,14 +534,18 @@ jobs: build_extended_minimal_android: name: 7. 
Build Extended Minimal (Android NNAPI) needs: build_full_ort # Depends on Job 1 for test data - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=build_extended_minimal_android-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: # Permissions needed for build-docker-image contents: read packages: write id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 diff --git a/.github/workflows/linux_openvino_ci.yml b/.github/workflows/linux_openvino_ci.yml index 0a4827087309e..e04bd6fd7d7af 100644 --- a/.github/workflows/linux_openvino_ci.yml +++ b/.github/workflows/linux_openvino_ci.yml @@ -39,6 +39,7 @@ jobs: run_tests: true upload_build_output: false + job_identifier: build_test_openvino # Secrets: Pass the necessary GitHub token secrets: diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 0e26576829e94..1d3a49909a634 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -34,6 +34,7 @@ jobs: run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory execution_providers: 'cuda tensorrt' + job_identifier: build-linux-TensorRT-x64-release secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Pass token for reusable workflow needs (e.g., docker build action) @@ -43,12 +44,13 @@ jobs: runs-on: - self-hosted - "1ES.Pool=Onnxruntime-github-Linux-GPU-H100" + - "JobId=test-linux-TensorRT-x64-release-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" permissions: contents: read packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing diff --git a/.github/workflows/linux_webgpu.yml b/.github/workflows/linux_webgpu.yml index f7161754895c5..28cf9fbf8f6ff 100644 --- a/.github/workflows/linux_webgpu.yml +++ b/.github/workflows/linux_webgpu.yml @@ -33,6 +33,7 @@ jobs: run_tests: false upload_build_output: true execution_providers: 'webgpu' + job_identifier: build-linux-webgpu-x64-release secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Pass token for reusable workflow needs (e.g., docker build action) diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index e545406d8d20f..8ba87bc1f731c 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -76,7 +76,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' @@ -124,7 +124,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 329584c68d7d1..8e1d0264496f6 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -75,7 +75,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: 
microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index abe627f4ff7bc..e4fb442869616 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -19,12 +19,16 @@ concurrency: jobs: auto-apply-fixes: name: Suggest fixes - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=pr-auto-apply-fixes-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] permissions: contents: read pull-requests: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 25b7899584bbf..749d587283595 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -22,9 +22,13 @@ permissions: jobs: build: name: Generate C/C++ API docs - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=apidocs-c-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 34b9c1af9552f..35cc5bab27e28 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -20,11 +20,15 @@ permissions: jobs: build: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=publish-csharp-apidocs-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install DocFX run: | dotnet tool update -g docfx diff --git a/.github/workflows/publish-gh-pages.yml b/.github/workflows/publish-gh-pages.yml index 11745ce24f9e5..d8f53f8dd698b 100644 --- a/.github/workflows/publish-gh-pages.yml +++ b/.github/workflows/publish-gh-pages.yml @@ -8,7 +8,11 @@ on: jobs: placeholder: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=publish-gh-pages-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - name: Placeholder step to have workflow included in the GitHub web UI run: | diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 656d0627ed17d..b3b831944c999 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -21,9 +21,13 @@ permissions: jobs: build: name: Generate Java docs - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=apidocs-java-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up JDK 11 uses: actions/setup-java@v5 with: diff --git 
a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index e71d3b3c57a4b..df507462a7d59 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -21,9 +21,13 @@ permissions: jobs: build: name: Generate JS API docs - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=apidocs-js-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Node.js uses: actions/setup-node@v6 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index 983d3d478a49d..a73b62eba6050 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Objective-C API docs runs-on: macos-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index 389d1683fb1ff..055f2909593ea 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -22,9 +22,13 @@ permissions: jobs: build: name: Generate Python API docs - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=apidocs-python-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 343186b1aec8c..f0f0dbfb410f6 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -14,13 +14,17 @@ concurrency: jobs: build_android_packages: name: Build Android AAR Packages - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=react_build_android_packages-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 120 outputs: aar_path: ${{ runner.temp }}/.artifacts steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -68,14 +72,18 @@ jobs: react_native_ci_android: name: React Native CI Android needs: build_android_packages - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=react_native_ci_android-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 90 steps: - name: Set ANDROID_AVD_HOME environment variable run: echo "ANDROID_AVD_HOME=${{ runner.temp }}/android-avd" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Python 3.12 uses: actions/setup-python@v6 @@ -175,7 +183,7 @@ jobs: timeout-minutes: 120 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Xcode 15.3.0 run: sudo xcode-select --switch 
/Applications/Xcode_15.3.0.app/Contents/Developer @@ -218,7 +226,7 @@ jobs: timeout-minutes: 90 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Download iOS pod artifact uses: actions/download-artifact@v6 diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index 795e35b06bfb0..f8b2170df9939 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -58,6 +58,10 @@ on: required: false type: boolean default: false + job_identifier: + description: 'A unique identifier for the job, used for hosted pool tracking' + required: true + type: string secrets: GH_TOKEN: description: 'GitHub token for accessing actions/packages' @@ -68,6 +72,7 @@ jobs: runs-on: - self-hosted - "1ES.Pool=${{ inputs.pool_name }}" + - "JobId=${{ inputs.job_identifier }}-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" permissions: contents: read packages: write @@ -75,7 +80,7 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Set up Python ${{ inputs.python_version }} if: inputs.architecture != 'arm64' diff --git a/.github/workflows/title-only-labeler.yml b/.github/workflows/title-only-labeler.yml index 7ee9f3917a901..e27337440aad3 100644 --- a/.github/workflows/title-only-labeler.yml +++ b/.github/workflows/title-only-labeler.yml @@ -8,7 +8,11 @@ permissions: jobs: triage: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU", + "JobId=title-only-labeler-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - uses: github/issue-labeler@v3.4 with: diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 016feab5e0d94..6ae25ccc0bf3e 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -22,7 +22,7 @@ jobs: commit_sha: ${{ steps.extract_commit.outputs.commit_sha }} steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: true diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index eee98332056f6..febc84ad4c86d 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -19,7 +19,11 @@ on: jobs: build_onnxruntime_web: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-WEBGPU-A10"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-WEBGPU-A10", + "JobId=build_onnxruntime_web-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] env: webgpu_commandline_extra_flags: "--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de" @@ -29,7 +33,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index 05fd4acd4de9a..265ea2bf8dcab 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -14,12 +14,16 @@ concurrency: jobs: build_x64: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-mms", + "JobId=windows-build-x64-asan-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" 
+ ] timeout-minutes: 300 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index fd5b65eb039a3..5821dcf6cebf7 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -19,9 +19,13 @@ concurrency: jobs: build: name: Windows GPU CUDA CI Pipeline - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=windows-cuda-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -150,9 +154,13 @@ jobs: name: Windows GPU CUDA CI Pipeline Test Job needs: build timeout-minutes: 300 - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=windows-cuda-test-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index e8ee7751348b4..de5eebdc86da0 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -25,9 +25,13 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=windows-dml-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 # Fetch all history for all tags and branches submodules: 'none' diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index b608c0879aa45..d90199978b969 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -18,7 +18,11 @@ concurrency: jobs: BUILD_OPENVINO_EP: name: Windows OpenVINO CI Pipeline - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=BUILD_OPENVINO_EP-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 240 env: AZCOPY_AUTO_LOGIN_TYPE: MSI @@ -31,7 +35,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -47,12 +51,12 @@ jobs: with: architecture: x64 - - name: Download OpenVINO Toolkit v2025.2.0 + - name: Download OpenVINO Toolkit v2025.3.0 env: - OpenVINOVersion: 2025.2.0 + OpenVINOVersion: 2025.3.0 shell: pwsh run: | - $Url ="https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.2/windows/openvino_toolkit_windows_2025.2.0.19140.c01cd93e24d_x86_64.zip" + $Url ="https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.3/windows/openvino_toolkit_windows_2025.3.0.19807.44526285f24_x86_64.zip" $OutputPath = "$env:RUNNER_TEMP\openvino.zip" $ExtractPath = "$env:RUNNER_TEMP\openvino-v$env:OpenVINOVersion" $TempExtractPath = "$env:RUNNER_TEMP\openvino_temp" @@ -95,7 +99,7 @@ jobs: shell: pwsh # Use $GITHUB_ENV to set the variable for subsequent steps run: | - 
$openVinoRootDir = Join-Path $env:RUNNER_TEMP "openvino-v2025.2.0" + $openVinoRootDir = Join-Path $env:RUNNER_TEMP "openvino-v2025.3.0" echo "OpenVINORootDir=$openVinoRootDir" >> $env:GITHUB_ENV - name: Print OpenVINORootDir after downloading OpenVINO diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index 4f0b50e65df6e..7b4c5663de96c 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -18,10 +18,14 @@ concurrency: jobs: build_test_qnn_ep: name: Windows x64 QNN CI Pipeline (${{ matrix.QnnLibKind }}) - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=build_test_qnn_ep-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 120 strategy: - matrix: + matrix: QnnLibKind: [shared_lib, static_lib] env: AZCOPY_AUTO_LOGIN_TYPE: MSI @@ -31,7 +35,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 @@ -50,7 +54,7 @@ jobs: azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/qnnsdk/qnn-v2.39.0.250926 . dir shell: pwsh - + - name: Set QNN_SDK_ROOT environment variable shell: pwsh run: | diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 229efb01f0018..5823194af4a1f 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -19,9 +19,13 @@ concurrency: jobs: build: name: Windows GPU TensorRT CI Pipeline - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=windows-tensorrt-build-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -155,9 +159,13 @@ jobs: name: Windows GPU TensorRT CI Pipeline Test Job needs: build timeout-minutes: 300 - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=windows-tensorrt-test-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 899a8b66eac7a..e515ffce09833 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -17,7 +17,11 @@ concurrency: jobs: webgpu_build_x64_RelWithDebInfo: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=webgpu_build_x64_RelWithDebInfo-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 strategy: matrix: @@ -34,7 +38,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -152,11 +156,15 @@ jobs: continue-on-error: true webgpu_external_dawn_build_x64_RelWithDebInfo: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + 
"JobId=webgpu_external_dawn_build_x64_RelWithDebInfo-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -198,7 +206,11 @@ jobs: working-directory: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo webgpu_minimal_build_edge_build_x64_RelWithDebInfo: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=webgpu_minimal_build_edge_build_x64_RelWithDebInfo-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 env: OrtPackageId: Microsoft.ML.OnnxRuntime @@ -209,7 +221,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index d62c7130e0ebb..4fa539fc81c3e 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -13,12 +13,16 @@ concurrency: jobs: build_x64_debug: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=windows-x64-debug-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index a2991bb0f1131..6da8d6b1321e7 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -13,12 +13,16 @@ concurrency: jobs: build_x64_release: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=windows-x64-release-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index bb6c5035b0dce..a1cfe789a411a 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -13,12 +13,16 @@ concurrency: jobs: build_x64_release_ep_generic_interface: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=build_x64_release_ep_generic_interface-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml 
b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index 4378231338673..a3200b21cf658 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -13,12 +13,16 @@ concurrency: jobs: build_x64_release_vitisai: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=build_x64_release_vitisai-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index b453cd570ac05..ab7fe33e1c137 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -13,12 +13,16 @@ concurrency: jobs: build_x64_release_xnnpack: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=build_x64_release_xnnpack-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -74,7 +78,7 @@ jobs: - name: NuGet restore shell: cmd run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - name: Build and Test shell: pwsh diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index d20778d56f60b..70ff4638ddd19 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -13,12 +13,16 @@ concurrency: jobs: build_x86_release: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-vs2022-latest", + "JobId=windows-x86-release-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] timeout-minutes: 300 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/SECURITY.md b/SECURITY.md index 869fdfe2b2469..d8e8bb9ca18a3 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,20 +1,18 @@ - + ## Security -Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations. -If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 
+If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. ## Reporting Security Issues **Please do not report security vulnerabilities through public GitHub issues.** -Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). -If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). - -You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). +You should receive a response within 24 hours. If for some reason you do not, please follow up using the messaging functionality found at the bottom of the Activity tab on your vulnerability report on [https://msrc.microsoft.com/report/vulnerability](https://msrc.microsoft.com/report/vulnerability/) or via email as described in the instructions at the bottom of [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc) or on MSRC's [FAQ page for reporting an issue](https://www.microsoft.com/en-us/msrc/faqs-report-an-issue). Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: @@ -28,7 +26,7 @@ Please include the requested information listed below (as much as you can provid This information will help us triage your report more quickly. -If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. ## Preferred Languages @@ -36,6 +34,6 @@ We prefer all communications to be in English. ## Policy -Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). diff --git a/SUPPORT.md b/SUPPORT.md new file mode 100644 index 0000000000000..a3d47db2c4e13 --- /dev/null +++ b/SUPPORT.md @@ -0,0 +1,13 @@ +# Support + +## How to file issues and get help + +This project uses [GitHub Issues](https://github.com/Microsoft/onnxruntime/issues) to track bugs and feature requests. Please search the existing +issues before filing new issues to avoid duplicates. For new issues, file your bug or +feature request as a new Issue. 
+ +For help and questions about using this project, please use [GitHub Discussions](https://github.com/microsoft/onnxruntime/discussions). + +## Microsoft Support Policy + +Support for this project is limited to the resources listed above. diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 7b2bbdd2094d1..fbd9f9a95f601 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -5806,41 +5806,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. _____ -composable_kernel - -https://github.com/ROCmSoftwarePlatform/composable_kernel - -Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) -Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) -Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) -Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) -Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) -Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) -Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) - -SPDX-License-Identifier: MIT -Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- -_____ - neural-speed https://github.com/intel/neural-speed diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake deleted file mode 100644 index 826bb7c468a02..0000000000000 --- a/cmake/external/composable_kernel.cmake +++ /dev/null @@ -1,66 +0,0 @@ -set(PATCH_CLANG ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) -set(PATCH_GFX12X ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Add_gfx12x_support.patch) - -include(FetchContent) -onnxruntime_fetchcontent_declare(composable_kernel - URL ${DEP_URL_composable_kernel} - URL_HASH SHA1=${DEP_SHA1_composable_kernel} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_CLANG} && - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX12X} - EXCLUDE_FROM_ALL -) - -FetchContent_GetProperties(composable_kernel) -if(NOT composable_kernel_POPULATED) - FetchContent_Populate(composable_kernel) - set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES}) - set(BUILD_DEV OFF CACHE BOOL "Disable -Weverything, otherwise, error: 'constexpr' specifier is incompatible with C++98 [-Werror,-Wc++98-compat]" FORCE) - # Exclude i8 device gemm instances due to excessive long compilation time and not being used - set(DTYPES fp32 fp16 bf16 fp8) - set(INSTANCES_ONLY ON) - add_subdirectory(${composable_kernel_SOURCE_DIR} ${composable_kernel_BINARY_DIR} EXCLUDE_FROM_ALL) - - add_library(onnxruntime_composable_kernel_includes INTERFACE) - target_include_directories(onnxruntime_composable_kernel_includes INTERFACE - ${composable_kernel_SOURCE_DIR}/include - ${composable_kernel_BINARY_DIR}/include - ${composable_kernel_SOURCE_DIR}/library/include) - target_compile_definitions(onnxruntime_composable_kernel_includes INTERFACE __fp32__ __fp16__ __bf16__) - - execute_process( - COMMAND ${Python3_EXECUTABLE} ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --list_blobs ${composable_kernel_BINARY_DIR}/blob_list.txt - COMMAND_ERROR_IS_FATAL ANY - ) - file(STRINGS ${composable_kernel_BINARY_DIR}/blob_list.txt generated_fmha_srcs) - add_custom_command( - OUTPUT ${generated_fmha_srcs} - COMMAND ${Python3_EXECUTABLE} ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py --output_dir ${composable_kernel_BINARY_DIR} - DEPENDS ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py ${composable_kernel_BINARY_DIR}/blob_list.txt - ) - set_source_files_properties(${generated_fmha_srcs} PROPERTIES LANGUAGE HIP GENERATED TRUE) - add_custom_target(gen_fmha_srcs DEPENDS ${generated_fmha_srcs}) # dummy target for dependencies - # code generation complete - - set(fmha_srcs - ${generated_fmha_srcs} - ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha/fmha_fwd.cpp - ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha/fmha_fwd.hpp - ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha/bias.hpp - ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha/mask.hpp - ) - add_library(onnxruntime_composable_kernel_fmha STATIC EXCLUDE_FROM_ALL ${generated_fmha_srcs}) - target_link_libraries(onnxruntime_composable_kernel_fmha PUBLIC onnxruntime_composable_kernel_includes) - target_include_directories(onnxruntime_composable_kernel_fmha PUBLIC ${composable_kernel_SOURCE_DIR}/example/ck_tile/01_fmha) - add_dependencies(onnxruntime_composable_kernel_fmha gen_fmha_srcs) - - # ck tile only supports MI200+ GPUs at the moment - get_target_property(archs onnxruntime_composable_kernel_fmha HIP_ARCHITECTURES) - string(REPLACE "," ";" archs "${archs}") - set(original_archs 
${archs}) - list(FILTER archs INCLUDE REGEX "(gfx942|gfx90a)") - if (NOT original_archs EQUAL archs) - message(WARNING "ck tile only supports archs: ${archs} among the originally specified ${original_archs}") - endif() - set_target_properties(onnxruntime_composable_kernel_fmha PROPERTIES HIP_ARCHITECTURES "${archs}") -endif() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index e1d98109208d4..1dcc7553fd608 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -42,7 +42,7 @@ function(get_c_cxx_api_headers HEADERS_VAR) foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) # The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory # with onnxruntime_c_api.h . Most other EPs probably also do not work in this way. - if((NOT f STREQUAL cuda) AND (NOT f STREQUAL rocm)) + if(NOT f STREQUAL cuda) file(GLOB _provider_headers CONFIGURE_DEPENDS "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" ) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index e449bb107e77b..1456c8caa8993 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -104,10 +104,6 @@ endif() if(onnxruntime_USE_CANN) target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CANN_HOME}/include) endif() -if(onnxruntime_USE_ROCM) - target_compile_options(onnxruntime_pybind11_state PUBLIC -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) - target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) -endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_pybind11_state PRIVATE ${NCCL_INCLUDE_DIRS}) endif() @@ -774,7 +770,6 @@ endif() if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS|visionOS|tvOS" AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android" - AND NOT onnxruntime_USE_ROCM AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD @@ -1044,16 +1039,6 @@ if (onnxruntime_USE_CANN) ) endif() -if (onnxruntime_USE_ROCM) - add_custom_command( - TARGET onnxruntime_pybind11_state POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy - $ - $ - $/onnxruntime/capi/ - ) -endif() - if (onnxruntime_USE_DML) if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(dml_shared_lib_path ${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}-win/${DML_SHARED_LIB}) diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index f81a7a9726b76..86e5810952a09 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -73,6 +73,3 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() -if (onnxruntime_USE_NCCL AND onnxruntime_USE_ROCM) - add_dependencies(onnxruntime_session generate_hipified_files) -endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 4913d38939792..9c6551ad5e792 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -677,7 +677,7 @@ if(onnxruntime_USE_ARMNN) endif() set(ONNXRUNTIME_TEST_STATIC_PROVIDER_LIBS - # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime. 
+ # CUDA, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime. # QNN EP can be built as either a dynamic and static libs. ${PROVIDERS_NNAPI} ${PROVIDERS_VSINPU} @@ -2094,6 +2094,50 @@ if (onnxruntime_BUILD_SHARED_LIB AND set_target_properties(example_plugin_ep_virt_gpu PROPERTIES FOLDER "ONNXRuntimeTest") source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src}) + # + # example_plugin_ep_kernel_registry + # + set(onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src + "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc") + onnxruntime_add_shared_library_module(example_plugin_ep_kernel_registry ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src}) + target_include_directories(example_plugin_ep_kernel_registry PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) + target_link_libraries(example_plugin_ep_kernel_registry PRIVATE onnxruntime ${GSL_TARGET}) + + if(UNIX) + if (APPLE) + set(ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG "-Xlinker -dead_strip") + elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") + string(CONCAT ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG + "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib.lds " + "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") + endif() + else() + set(ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG + "-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib.def") + endif() + + set_property(TARGET example_plugin_ep_kernel_registry APPEND_STRING PROPERTY LINK_FLAGS + ${ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG}) + + set_target_properties(example_plugin_ep_kernel_registry PROPERTIES FOLDER "ONNXRuntimeTest") + source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src}) + # # test library # @@ -2129,7 +2173,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND TARGET onnxruntime_autoep_test SOURCES ${onnxruntime_autoep_test_SRC} ${onnxruntime_unittest_main_src} LIBS ${onnxruntime_autoep_test_LIBS} - DEPENDS ${all_dependencies} example_plugin_ep example_plugin_ep_virt_gpu + DEPENDS 
${all_dependencies} example_plugin_ep example_plugin_ep_virt_gpu example_plugin_ep_kernel_registry ) endif() diff --git a/cmake/patches/composable_kernel/Add_gfx12x_support.patch b/cmake/patches/composable_kernel/Add_gfx12x_support.patch deleted file mode 100644 index ef529184d2ed8..0000000000000 --- a/cmake/patches/composable_kernel/Add_gfx12x_support.patch +++ /dev/null @@ -1,2280 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index bc326c8b5..db5ad5052 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -117,7 +117,7 @@ else() - add_definitions(-DPROFILER_ONLY) - set(GPU_TARGETS "" CACHE STRING "" FORCE) - if(GPU_TARGETS) -- message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") -+ message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") - endif() - if(GPU_ARCH MATCHES "gfx90") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") -@@ -127,8 +127,10 @@ else() - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") - elseif(GPU_ARCH MATCHES "gfx11") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") -+ elseif(GPU_ARCH MATCHES "gfx12") -+ rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") - else() -- message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") -+ message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") - endif() - set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) - endif() -diff --git a/Jenkinsfile b/Jenkinsfile -index 75800bfc9..b72e2ca4e 100644 ---- a/Jenkinsfile -+++ b/Jenkinsfile -@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){ - - def variant = env.STAGE_NAME - def retimage -+ - gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { - try { - (retimage, image) = getDockerImage(conf) -@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;ROCM - - pipeline { - agent none -- triggers { -- parameterizedCron(CRON_SETTINGS) -- } - options { - parallelsAlwaysFailFast() - } -diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake -index 8654170b3..42070051b 100644 ---- a/cmake/EnableCompilerWarnings.cmake -+++ b/cmake/EnableCompilerWarnings.cmake -@@ -66,7 +66,7 @@ else() - -Wunreachable-code - -Wunused - -Wno-reserved-identifier -- -Werror -+ -Werror - -Wno-option-ignored - -Wsign-compare - -Wno-extra-semi-stmt -diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp -index 8c52e4f7d..f8afe8d6d 100644 ---- a/example/01_gemm/gemm_wmma_fp16.cpp -+++ b/example/01_gemm/gemm_wmma_fp16.cpp -@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa - - // clang-format off - using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle -- < ALayout, -- BLayout, -- CLayout, -- ADataType, -+ < ALayout, -+ BLayout, -+ CLayout, -+ ADataType, - BDataType, -- CDataType, -- AccDataType, -- CShuffleDataType, -- AElementOp, -- BElementOp, -- CElementOp, -- GemmDefault, -+ CDataType, -+ AccDataType, -+ CShuffleDataType, -+ AElementOp, -+ BElementOp, -+ CElementOp, -+ GemmDefault, - 1, // Prefetch stage - 128, // BlockSize - 64, // MPerBlock - 128, // NPerBlock - 64, // KPerBlock -- 8, // K1 -+ 2, // K1 - 16, // MPerWmma - 16, // NPerWmma - 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave - 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave -- S<4, 32, 1>, -- S<1, 0, 2>, -- S<1, 0, 2>, -- 2, -- 8, -- 8, -- true, -- S<4, 32, 1>, -- S<1, 0, 2>, -- S<1, 0, 2>, -- 2, -- 8, -- 8, -- true, -+ S<4, 32, 1>, -+ S<1, 0, 2>, -+ S<1, 0, 2>, -+ 2, -+ 2, -+ 2, -+ true, -+ S<4, 32, 1>, -+ S<1, 0, 2>, -+ S<1, 0, 2>, -+ 2, -+ 2, -+ 2, -+ true, - 1, // C shuffle (M Repeat) Per store - 1, // C shuffle (N Repeat) Per store -- S<1, 32, 1, 4>, -+ S<1, 32, 1, 4>, - 8>; - // clang-format on - -diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc -index b04e4e53a..cb15186c3 100644 ---- a/example/01_gemm/run_gemm_example.inc -+++ b/example/01_gemm/run_gemm_example.inc -@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); - break; - case 4: -- ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); -+ ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); - break; - case 5: -diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt -index ab19f819e..be47665a2 100644 ---- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt -+++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt -@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) - set(target 1) - endif() --endforeach() -\ No newline at end of file -+endforeach() -diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp -index 2bbf430c4..f556be887 100644 ---- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp -+++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp -@@ -83,14 +83,14 @@ using 
DeviceOpInstanceKKNN = - 2, - 4, - 4, -- true, -+ false, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 4, - 4, -- true, -+ false, - 1, - 1, - S<1, 64, 1, 2>, -diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp -index 4c92c5497..fac19f8b5 100644 ---- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp -+++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp -@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial - #define CK_MHA_USE_WAVE_1 - #define CK_MHA_USE_WAVE_2 - #define CK_MHA_USE_WAVE_4 --#define CK_MHA_USE_WAVE_8 -+//#define CK_MHA_USE_WAVE_8 - using DeviceMHAFactory = - std::tuple< - #ifdef CK_MHA_USE_WAVE_1 -@@ -277,10 +277,10 @@ using DeviceMHAFactory = - S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, - // CShuffleBlockTransfer MN - 1, 1, S<1, 64, 1, 2>, 8, -- MaskingSpec>, -+ MaskingSpec> - #endif - #ifdef CK_MHA_USE_WAVE_8 -- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< -+ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< - NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, - ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, - AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, -diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp -index 8e037272b..d463cc871 100644 ---- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp -+++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp -@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial - #define CK_MHA_USE_WAVE_1 - #define CK_MHA_USE_WAVE_2 - #define CK_MHA_USE_WAVE_4 --#define CK_MHA_USE_WAVE_8 -+//#define CK_MHA_USE_WAVE_8 - using DeviceMHAFactory = - std::tuple< - #ifdef CK_MHA_USE_WAVE_1 -@@ -277,10 +277,10 @@ using DeviceMHAFactory = - S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, - // CShuffleBlockTransfer MN - 1, 1, S<1, 64, 1, 2>, 8, -- MaskingSpec>, -+ MaskingSpec> - #endif - #ifdef CK_MHA_USE_WAVE_8 -- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< -+ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< - NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, - ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, - AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, -diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt -index 5465adb77..7534bff3b 100644 ---- a/example/CMakeLists.txt -+++ b/example/CMakeLists.txt -@@ -60,7 +60,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) -- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") -+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") - message("removing wmma example ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() -@@ -134,7 +134,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) - 
endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) -- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") -+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") - message("removing wmma example ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() -diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp -index 55f562061..69a7abf62 100644 ---- a/include/ck/ck.hpp -+++ b/include/ck/ck.hpp -@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) - #define __gfx11__ - #endif -+#if defined(__gfx1200__) || defined(__gfx1201__) -+#define __gfx12__ -+#endif - - // buffer resource - #ifndef __HIP_DEVICE_COMPILE__ // for host code -@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 - #elif defined(__gfx103__) - #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 --#elif defined(__gfx11__) -+#elif defined(__gfx11__) || defined(__gfx12__) - #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 - #endif - -@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - #define CK_USE_AMD_V_FMAC_F32 - #define CK_USE_AMD_V_DOT2_F32_F16 - #define CK_USE_AMD_V_DOT4_I32_I8 --#elif defined(__gfx11__) -+#elif defined(__gfx11__) || defined(__gfx12__) - #define CK_USE_AMD_V_FMAC_F32 - #define CK_USE_AMD_V_DOT2_F32_F16 - #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 -@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - #define CK_USE_AMD_MFMA_GFX940 - #endif - --// WMMA instruction --#ifndef __HIP_DEVICE_COMPILE__ // for host code --#define CK_USE_AMD_WMMA --#elif defined(__gfx11__) // for GPU code --#define CK_USE_AMD_WMMA --#endif -- - // buffer load - #define CK_USE_AMD_BUFFER_LOAD 1 - -diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp -index 116bb3ea0..83af2efe8 100644 ---- a/include/ck/host_utility/device_prop.hpp -+++ b/include/ck/host_utility/device_prop.hpp -@@ -84,4 +84,9 @@ inline bool is_gfx11_supported() - ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; - } - -+inline bool is_gfx12_supported() -+{ -+ return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; -+} -+ - } // namespace ck -diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp -index f8ee283c6..7eb7d42eb 100644 ---- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp -+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp -@@ -13,6 +13,504 @@ - - namespace ck { - -+#ifdef __gfx12__ -+template -+/* Option: Read from LDS, big buffer hold all threads required data -+ * Source -+ * A: K0PerBlock x MPerBlock x K1 -+ * B: K0PerBlock x NPerBlock x K1 -+ * Destination -+ * C, non-transpose -+ * thread level: MRepeat x NRepeat x MAccVgprs -+ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs -+ * KPACK == WMMA_K = 16 -+ * -+ * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) -+ * Source: -+ * A(if skip LDS): MRepeat x KPack -+ * B(if skip LDS): NRepeat x KPack -+ * Destination -+ * C, non-transpose -+ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs -+ */ -+struct BlockwiseGemmWMMA -+{ -+ static constexpr auto I0 = Number<0>{}; -+ static constexpr auto I1 = 
Number<1>{}; -+ static constexpr auto I2 = Number<2>{}; -+ static constexpr auto I3 = Number<3>{}; -+ static constexpr auto I4 = Number<4>{}; -+ static constexpr auto I5 = Number<5>{}; -+ static constexpr auto WmmaK = Number<16>{}; -+ -+ using ThisThreadBlock = ThisThreadBlock; -+ -+ // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. -+ static constexpr index_t WaveSize = 32; -+ -+ // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer -+ // When not use LDS, each Row read half of whole data from source buffer, exchange the data via -+ // permutation -+ static constexpr index_t A_KRow = 2; -+ static constexpr index_t B_KRow = 2; -+ -+ static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); -+ static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); -+ -+ static constexpr auto wmma_gemm = -+ WmmaGemm{}; -+ -+ static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); -+ static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); -+ -+ StaticBufferTupleOfVector -+ c_thread_buf_; -+ -+ __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } -+ -+ __device__ static auto GetWaveIdx() -+ { -+ const index_t thread_id = ThisThreadBlock::GetThreadId(); -+ -+ constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( -+ make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), -+ make_tuple(Sequence<0, 1, 2>{}), -+ make_tuple(Sequence<0>{})); -+ -+ return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); -+ } -+ -+ // Default, Block buffer in LDS, thread level offset enabled -+ __device__ static auto CalculateAThreadOriginDataIndex() -+ { -+ if constexpr(AEnableLds) -+ { -+ const auto wave_idx = GetWaveIdx(); -+ const auto waveId_m = wave_idx[I0]; -+ const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); -+ -+ // |KRepeat |MRepeat|MWave |KRow |MLane |KPack -+ return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); -+ } -+ else -+ { -+ return make_tuple(0, 0, 0, 0, 0, 0); -+ } -+ } -+ -+ __device__ static auto CalculateBThreadOriginDataIndex() -+ { -+ if constexpr(BEnableLds) -+ { -+ const auto wave_idx = GetWaveIdx(); -+ const auto waveId_n = wave_idx[I1]; -+ const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); -+ -+ // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack -+ return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); -+ } -+ else -+ { -+ return make_tuple(0, 0, 0, 0, 0, 0); -+ } -+ } -+ -+ template -+ __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) -+ { -+ const auto wave_idx = GetWaveIdx(); -+ -+ const auto waveId_m = wave_idx[I0]; -+ const auto waveId_n = wave_idx[I1]; -+ -+ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); -+ -+ constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( -+ make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), -+ make_tuple(Sequence<0>{}), -+ make_tuple(Sequence<0, 1, 2>{})); -+ -+ constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( -+ make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), -+ make_tuple(Sequence<0>{}), -+ make_tuple(Sequence<0, 1, 2>{})); -+ -+ const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( -+ make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; -+ const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( -+ 
make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; -+ -+ return make_tuple(c_thread_m, c_thread_n); -+ } -+ -+ template -+ __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) -+ { -+ const auto wave_idx = GetWaveIdx(); -+ -+ const auto waveId_m = wave_idx[I0]; -+ const auto waveId_n = wave_idx[I1]; -+ -+ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); -+ -+ return make_tuple( -+ Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); -+ } -+ -+ using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); -+ __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), -+ Tuple6 b_origin = CalculateBThreadOriginDataIndex()) -+ : a_thread_copy_(a_origin), b_thread_copy_(b_origin) -+ { -+ static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), -+ "wrong! Desc should be known at compile-time"); -+ -+ static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, -+ "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); -+ -+ static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && -+ NPerBlock % (NPerWMMA * NRepeat) == 0, -+ "wrong!"); -+ } -+ -+ // transposed WMMA output C' = B' * A' -+ __host__ __device__ static constexpr auto -+ GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() -+ { -+ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = -+ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); -+ -+ constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; -+ -+ return make_naive_tensor_descriptor_packed( -+ // |MRepeat |MWave |MSubGroup |NRepeat |NWave -+ // |NThreadPerSubGroup |MAccVgprs -+ make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); -+ } -+ -+ // Thread level, register decriptor. 
Vector-write -+ __host__ __device__ static constexpr auto -+ GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() -+ { -+ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = -+ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); -+ -+ constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; -+ constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; -+ return make_naive_tensor_descriptor( -+ // |MRepeat |MWave |MSubGroup |NRepeat |NWave -+ // |NThreadPerSubGroup |MAccVgprs -+ make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), -+ make_tuple(Number{} * MAccVgprs * AccStride, -+ Number{} * MAccVgprs * AccStride, -+ Number{} * MAccVgprs * AccStride, -+ MAccVgprs * AccStride, -+ MAccVgprs * AccStride, -+ MAccVgprs * AccStride, -+ AccStride)); -+ } -+ -+ template -+ __host__ __device__ static constexpr auto -+ MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( -+ const CGridDesc_M_N& c_grid_desc_m_n) -+ { -+ const auto M = c_grid_desc_m_n.GetLength(I0); -+ const auto N = c_grid_desc_m_n.GetLength(I1); -+ -+ const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = -+ transform_tensor_descriptor( -+ c_grid_desc_m_n, -+ make_tuple( -+ make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), -+ make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), -+ make_tuple(Sequence<0>{}, Sequence<1>{}), -+ make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); -+ -+ return wmma_gemm -+ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( -+ c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); -+ } -+ -+ // transposed WMMA output C' = B' * A' -+ __host__ __device__ static constexpr auto -+ GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() -+ { -+ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = -+ make_naive_tensor_descriptor_packed(make_tuple(Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number{})); -+ -+ return wmma_gemm -+ .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( -+ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); -+ } -+ -+ // Provide dimension size -+ __host__ __device__ static constexpr auto -+ GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() -+ { -+ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = -+ make_naive_tensor_descriptor_packed(make_tuple(Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number{})); -+ -+ return wmma_gemm -+ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( -+ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); -+ } -+ -+ // Describe how data allocated in thread copy src buffer -+ // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma -+ static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; -+ static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; -+ -+ template -+ __device__ void Run(const ABlockBuffer& a_block_buf, -+ const BBlockBuffer& b_block_buf, -+ CThreadBuffer& c_thread_buf) const -+ { -+ auto a_thread_buf = make_static_buffer( -+ a_thread_desc_.GetElementSpaceSize()); -+ auto b_thread_buf = make_static_buffer( -+ b_thread_desc_.GetElementSpaceSize()); -+ -+ 
static_assert(KPack % (A_K1 * A_KRow) == 0, ""); -+ static_assert(KPack % (B_K1 * B_KRow) == 0, ""); -+ -+ // basic intrinsic to determine loopover direction -+ if constexpr(MRepeat < NRepeat) -+ { -+ static_for<0, KPerBlock / KPack, 1>{}( -+ [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... -+ static_for<0, MRepeat, 1>{}([&](auto m0) { -+ // read A -+ a_thread_copy_.Run( -+ a_block_desc_k0_m0_m1_m2_k1, -+ make_tuple(Number{}, m0, I0, I0, I0, I0), -+ a_block_buf, -+ a_thread_desc_, -+ make_tuple(I0, m0, I0, I0, I0, I0), -+ a_thread_buf); -+ -+ static_for<0, NRepeat, 1>{}([&](auto n0) { -+ // read B -+ b_thread_copy_.Run( -+ b_block_desc_k0_n0_n1_n2_k1, -+ make_tuple(Number{}, n0, I0, I0, I0, I0), -+ b_block_buf, -+ b_thread_desc_, -+ make_tuple(I0, n0, I0, I0, I0, I0), -+ b_thread_buf); -+ -+ vector_type a_thread_vec; -+ vector_type b_thread_vec; -+ -+ static_for<0, KPack / A_KRow, 1>{}([&](auto i) { -+ a_thread_vec.template AsType()(i) = -+ a_thread_buf[Number{}]; -+ }); -+ -+ static_for<0, KPack / B_KRow, 1>{}([&](auto i) { -+ b_thread_vec.template AsType()(i) = -+ b_thread_buf[Number{}]; -+ }); -+ -+ using wmma_input_type_a = -+ typename vector_type::type; -+ using wmma_input_type_b = -+ typename vector_type::type; -+ -+ constexpr index_t c_offset = -+ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); -+ -+ wmma_gemm.template Run( -+ a_thread_vec.template AsType(), -+ b_thread_vec.template AsType(), -+ c_thread_buf.GetVectorTypeReference(Number{})); -+ }); -+ }); -+ }); -+ } -+ else -+ { -+ static_for<0, NRepeat, 1>{}([&](auto n0) { -+ static_for<0, MRepeat, 1>{}([&](auto m0) { -+ static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of -+ // k=0,kpack*1, .. -+ // read B -+ b_thread_copy_.Run( -+ b_block_desc_k0_n0_n1_n2_k1, -+ make_tuple(Number{}, n0, I0, I0, I0, I0), -+ b_block_buf, -+ b_thread_desc_, -+ make_tuple(I0, n0, I0, I0, I0, I0), -+ b_thread_buf); -+ // read A -+ a_thread_copy_.Run( -+ a_block_desc_k0_m0_m1_m2_k1, -+ make_tuple(Number{}, m0, I0, I0, I0, I0), -+ a_block_buf, -+ a_thread_desc_, -+ make_tuple(I0, m0, I0, I0, I0, I0), -+ a_thread_buf); -+ -+ vector_type a_thread_vec; -+ vector_type b_thread_vec; -+ -+ static_for<0, KPack / A_KRow, 1>{}([&](auto i) { -+ a_thread_vec.template AsType()(i) = -+ a_thread_buf[Number{}]; -+ }); -+ -+ static_for<0, KPack / B_KRow, 1>{}([&](auto i) { -+ b_thread_vec.template AsType()(i) = -+ b_thread_buf[Number{}]; -+ }); -+ -+ using wmma_input_type_a = -+ typename vector_type::type; -+ using wmma_input_type_b = -+ typename vector_type::type; -+ -+ constexpr index_t c_offset = -+ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); -+ -+ wmma_gemm.template Run( -+ a_thread_vec.template AsType(), -+ b_thread_vec.template AsType(), -+ c_thread_buf.GetVectorTypeReference(Number{})); -+ }); -+ }); -+ }); -+ } -+ } -+ -+ protected: -+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( -+ make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), -+ make_tuple(Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number<1>{})); -+ -+ static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( -+ make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), -+ make_tuple(Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number{}, -+ Number<1>{})); -+ -+ // C[M, N, NumRegWMMA] -+ static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( -+ make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); -+ -+ template -+ struct AThreadCopySelector; -+ -+ template <> -+ struct 
AThreadCopySelector -+ { -+ using type = -+ ThreadwiseTensorSliceTransfer_v4, -+ Sequence<0, 1, 2, 3, 4, 5>, -+ 5, -+ A_K1, -+ A_K1>; -+ }; -+ -+ template <> -+ struct AThreadCopySelector -+ { -+ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< -+ FloatA, -+ FloatA, -+ decltype(a_block_desc_k0_m0_m1_m2_k1), -+ decltype(a_thread_desc_), -+ tensor_operation::element_wise::PassThrough, -+ Sequence, -+ Sequence<0, 1, 2, 3, 4, 5>, -+ 5, -+ A_K1, -+ false>; -+ }; -+ -+ template -+ struct BThreadCopySelector; -+ -+ template <> -+ struct BThreadCopySelector -+ { -+ using type = -+ ThreadwiseTensorSliceTransfer_v4, -+ Sequence<0, 1, 2, 3, 4, 5>, -+ 5, -+ B_K1, -+ B_K1>; -+ }; -+ -+ template <> -+ struct BThreadCopySelector -+ { -+ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< -+ FloatB, -+ FloatB, -+ decltype(b_block_desc_k0_n0_n1_n2_k1), -+ decltype(b_thread_desc_), -+ tensor_operation::element_wise::PassThrough, -+ Sequence, -+ Sequence<0, 1, 2, 3, 4, 5>, -+ 5, -+ B_K1, -+ false>; -+ }; -+ -+ typename AThreadCopySelector::type a_thread_copy_; -+ typename BThreadCopySelector::type b_thread_copy_; -+}; -+#else - template ::type a_thread_copy_; - typename BThreadCopySelector::type b_thread_copy_; - }; -+#endif - - } // namespace ck -diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp -index e5e6245cb..1f7d50429 100644 ---- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp -+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp -@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 - // sync point. - if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) - { -+#ifdef __gfx12__ -+ asm volatile("\ -+ s_barrier_signal -1 \n \ -+ s_barrier_wait -1 \ -+ " ::); -+#else - asm volatile("s_barrier" ::); -+#endif - __builtin_amdgcn_sched_barrier(0); - } - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp -index a15759559..ab3f3856a 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp -@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle - static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - -- static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; -- static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; -+ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; -+ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; -+ -+ static constexpr auto AEnableLds_auto = -+ (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; -+ static constexpr auto BEnableLds_auto = -+ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? 
false : true; - - // If true, LDS is used unconditionally - static constexpr auto AEnableLds_manu = false; -@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle - - static bool IsSupportedArgument(const Argument& arg) - { -- if(ck::is_gfx11_supported()) -+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - if constexpr(!(is_same_v || is_same_v)) - { -@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle - } - else - { -- if(!(arg.a_kz_stride_ == 1 && -- arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) -+ if(!(arg.a_kz_stride_ == 1)) - { -- printf("DeviceOp: Vector Access A-k check failure\n"); -- return false; -+ index_t LastK = -+ AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); -+ if(LastK % ABlockTransferSrcScalarPerVector == 0) -+ { -+ printf("DeviceOp: Vector Access A-k check failure\n"); -+ return false; -+ } - } - } - -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp -index 8fd14afc0..1b487502f 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp -@@ -70,8 +70,9 @@ __global__ void - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_ctile_map) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ -- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ -+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ -+ defined(__gfx12__)) - - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); -@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD || is_same_v)) - { -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp -index 9d5b74be6..017d28641 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp -@@ -601,9 +601,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle - return false; - } - -- if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && -- ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && -- std::is_same::value) -+ if(!ck::is_lds_direct_load_supported() && std::is_same::value) - { - return false; - } -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp -index b84e18130..1edae33be 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp -@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl - { - // check device - if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || -- ck::is_gfx11_supported())) -+ ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - return false; - } 
-diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp -index bf96324d0..553143e28 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp -@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB || is_same_v || - is_same_v)) -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp -index b1784b385..eb0fb55f5 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp -@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) - { -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp -index 93ab8a7e1..a7cc546f5 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp -@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; - -- static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); -- static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); -- static constexpr auto WmmaK = K1 == 16 ? 32 : 16; -- -- static constexpr auto AEnableLds_auto = -- (NWaves == 1 && is_same::value) ? false : true; -+ static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); -+ static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); -+ static constexpr auto WmmaK = K1 == 16 ? 32 : 16; -+ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; -+ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; -+ -+ static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && -+ is_same::value) -+ ? false -+ : true; - static constexpr auto BEnableLds_auto = -- (MWaves == 1 && is_same::value) ? false : true; -+ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && -+ is_same::value) -+ ? 
false -+ : true; - - // If true, LDS is used unconditionally - static constexpr auto AEnableLds_manu = false; -@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || - is_same_v)) -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp -index 6f74838fb..6bb5d431c 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp -@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle - static bool IsSupportedArgument(const Argument& arg) - { - // check device -- if(ck::is_gfx11_supported()) -+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - if constexpr(!(is_same_v || is_same_v)) - { -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp -index bd264a3c8..7047e1bda 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp -@@ -48,8 +48,9 @@ __global__ void - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ -- defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ -+ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ -+ defined(__gfx12__)) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp -index 211185dfb..5738be0fb 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp -@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle - static bool IsSupportedArgument(const Argument& arg) - { - // check device -- if(ck::is_gfx11_supported()) -+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - if constexpr(!(is_same_v || is_same_v)) - { -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp -index 7cfbd8a8f..5d5a9de7d 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp -@@ -90,8 +90,9 @@ __global__ void - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ -- defined(__gfx90a__) 
|| defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ -+ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ -+ defined(__gfx12__)) - // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); -@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK - - // check device - if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || -- ck::is_gfx103_supported() || ck::is_gfx11_supported())) -+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - return false; - } -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp -index 6a4d97d7d..c65370b51 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp -@@ -107,7 +107,7 @@ __global__ void - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) - { - #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ -- defined(__gfx11__)) -+ defined(__gfx11__) || defined(__gfx12__)) - // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); -@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v)) - { -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp -index ac392cddc..060a16d1e 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp -@@ -39,8 +39,9 @@ __global__ void - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ -- defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ -+ defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ -+ defined(__gfx12__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id = get_block_1d_id(); -@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm || is_same_v)) - { -diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp -index 4e14ed3a5..cc88c1a10 100644 ---- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp -+++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp -@@ -60,7 +60,7 @@ __global__ void - bool input_permute, - bool output_permute) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - // clang-format off - // 
*************************************************** -@@ -165,6 +165,7 @@ __global__ void - ignore = O; - ignore = G0; - ignore = G1; -+ ignore = alpha; - ignore = input_permute; - ignore = output_permute; - #endif // end of if (defined(__gfx11__)) -@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma - - static bool IsSupportedArgument(const RawArg& arg) - { -- if(ck::is_gfx11_supported()) -+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - if constexpr(!(is_same_v || is_same_v)) - { -diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp -index 16717ff81..1754e07e6 100644 ---- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp -+++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp -@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma - if constexpr(B0EnableLds) - { - // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 -- constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); -- constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); -+ constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); -+ constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); -+#ifdef __gfx12__ -+ constexpr auto B_KRow = I2; -+#else - constexpr auto B_KRow = I1; -+#endif - return transform_tensor_descriptor( - B0BlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma - if constexpr(B1EnableLds) - { - // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 -- constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); -- constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); -+ constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); -+ constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); -+#ifdef __gfx12__ -+ constexpr auto B_LRow = I2; -+#else - constexpr auto B_LRow = I1; -+#endif - return transform_tensor_descriptor( - B1BlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp -index 499eb7eb0..21dac6f9e 100644 ---- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp -+++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp -@@ -50,7 +50,7 @@ __global__ void - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; - - GridwiseGemm::template Run(p_a_grid, -@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma - if constexpr(AEnableLds) - { - // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 -- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); -- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); -+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); 
-+#ifdef __gfx12__ -+ constexpr auto A_KRow = I2; -+#else - constexpr auto A_KRow = I1; -+#endif - return transform_tensor_descriptor( - ABlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma - if constexpr(BEnableLds) - { - // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 -- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); -- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); -+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); -+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); -+#ifdef __gfx12__ -+ constexpr auto B_KRow = I2; -+#else - constexpr auto B_KRow = I1; -+#endif - return transform_tensor_descriptor( - BBlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp -index 82d010a99..fdda649ef 100644 ---- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp -+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp -@@ -54,7 +54,7 @@ __global__ void - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); -@@ -147,7 +147,7 @@ __global__ void - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_etile_map) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - // printf("entry kernel launch"); - __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; - -@@ -237,7 +237,7 @@ __global__ void - const CDEElementwiseOperation cde_element_op, - const Block2CTileMap block_2_ctile_map) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; - - GridwiseOp::template Run(p_a_grid, -@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma - } - else - { -+ constexpr auto A_KRow = I2; - constexpr auto KWmmaPerblock = KPerBlock / WmmaK; -- constexpr auto K0PerWmma = WmmaK / 2 / K1; -+ constexpr auto K0PerWmma = WmmaK / A_KRow / K1; - // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread - return make_naive_tensor_descriptor( - make_tuple(Number{}, -@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma - } - else - { -+ constexpr auto B_KRow = I2; - constexpr auto KWmmaPerblock = KPerBlock / WmmaK; -- constexpr auto K0PerWmma = WmmaK / 2 / K1; -+ constexpr auto K0PerWmma = WmmaK / B_KRow / K1; - // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread - return make_naive_tensor_descriptor( - make_tuple(Number{}, -@@ -495,12 +497,16 @@ struct 
GridwiseGemmMultipleD_Wmma - if constexpr(AEnableLds) - { - // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 -- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); -- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); -+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -+#ifdef __gfx12__ -+ constexpr auto A_KRow = I2; -+#else - constexpr auto A_KRow = I1; -+#endif - return transform_tensor_descriptor( - ABlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma - if constexpr(BEnableLds) - { - // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 -- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); -- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); -+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); -+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); -+#ifdef __gfx12__ -+ constexpr auto B_KRow = I2; -+#else - constexpr auto B_KRow = I1; -+#endif - return transform_tensor_descriptor( - BBlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma - // *Caution Here repeat is shuffle repeat - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() - { -- constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); -- constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); -- - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - make_naive_tensor_descriptor_packed( - make_tuple(I1, -- Number{}, -+ Number{}, - I1, -- Number{})); -+ Number{})); - - return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; - } -@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma - const auto M = e_grid_desc_m_n.GetLength(I0); - const auto N = e_grid_desc_m_n.GetLength(I1); - -- const auto MBlock = M / MPerBlock; -- const auto NBlock = N / NPerBlock; -+ const auto MBlock = M / MPerBlock; -+ const auto NBlock = N / NPerBlock; -+ - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - e_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), -diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp -index 8e4117593..4458b9356 100644 ---- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp -+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp -@@ -45,7 +45,7 @@ __global__ void - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; - - GridwiseGemm::template Run(p_a_grid, -@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma - } - else - { -+ constexpr auto A_KRow = I2; - constexpr auto KWmmaPerblock = KPerBlock / WmmaK; -- constexpr auto K0PerWmma = WmmaK / 2 / K1; -+ constexpr auto K0PerWmma = WmmaK / A_KRow / K1; - // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 
Per Thread - return make_naive_tensor_descriptor( - make_tuple(Number{}, -@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma - } - else - { -+ -+ constexpr auto B_KRow = I2; - constexpr auto KWmmaPerblock = KPerBlock / WmmaK; -- constexpr auto K0PerWmma = WmmaK / 2 / K1; -+ constexpr auto K0PerWmma = WmmaK / B_KRow / K1; - // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread - return make_naive_tensor_descriptor( - make_tuple(Number{}, -@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma - if constexpr(AEnableLds) - { - // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 -- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); -- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); -+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -+#ifdef __gfx12__ -+ constexpr auto A_KRow = I2; -+#else - constexpr auto A_KRow = I1; -+#endif -+ - return transform_tensor_descriptor( - ABlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma - if constexpr(BEnableLds) - { - // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 -- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); -- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); -+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); -+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); -+#ifdef __gfx12__ -+ constexpr auto B_KRow = I2; -+#else - constexpr auto B_KRow = I1; -+#endif - return transform_tensor_descriptor( - BBlockDesc_{}, -- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), -+ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), -@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma - c_grid_desc_m_n); - } - -- using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = -- remove_cvref_t; -- using DefaultBlock2CTileMap = -- remove_cvref_t; -- - struct SharedMemTrait - { - // LDS allocation for A and B: be careful of alignment -@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma - b_block_space_size_aligned * sizeof(BDataType)); - }; - -+ using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = -+ remove_cvref_t; -+ using DefaultBlock2CTileMap = -+ remove_cvref_t; -+ - template - __device__ static void Run(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, -diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp -index 6772524e0..174074990 100644 ---- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp -+++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp -@@ -35,8 +35,9 @@ __global__ void - const Block2ETileMap block_2_tile_map, - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) - { --#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ -- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) -+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ -+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ -+ defined(__gfx12__)) - GridwiseTensorRearrangeKernel::Run(in_grid_desc, - p_in_global, - 
out_grid_desc, -diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp -index bcce930fc..d7a6a3624 100644 ---- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp -+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp -@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic - ElementwiseOperation element_op_; - }; - --// Specilized for WMMA -+// Specilized for WMMA-Navi3 - // A single Wave32 is composed by double row - // Data exchange allowed between these two rows - // This RowLane Dst buf will be filled from two Src buf -@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow - ElementwiseOperation element_op_{}; - }; - -+// Specilized for WMMA-Navi4 -+template ::type = false> -+struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow -+{ -+ static constexpr index_t nDim = SliceLengths::Size(); -+ -+ using Index = MultiIndex; -+ -+ __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) -+ { -+ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), -+ "wrong! Desc need to known at compile-time"); -+ -+ static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, -+ "wrong! Not divisible"); -+ ignore = src_idx; -+ } -+ -+ template -+ __device__ void Run(const SrcDesc&, -+ const SrcSliceOriginIdx&, -+ const SrcBuffer& src_buf, -+ const DstDesc&, -+ const DstSliceOriginIdx&, -+ DstBuffer& dst_buf) const -+ { -+ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), -+ "wrong! Desc need to known at compile-time"); -+ -+ static_assert(is_known_at_compile_time>::value && -+ is_known_at_compile_time>::value, -+ "wrong! SliceOrigin need to known at compile-time"); -+ -+ static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), -+ "wrong! 
Buffer need to be StaticBuffer"); -+ -+ // SrcDesc and src_slice_origin_idx are known at compile-time -+ constexpr auto src_desc = remove_cvref_t{}; -+ constexpr auto dst_desc = remove_cvref_t{}; -+ constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); -+ constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); -+ -+ // scalar per access on each dim -+ constexpr auto dst_scalar_per_access = generate_sequence( -+ detail::lambda_scalar_per_access{}, Number{}); -+ -+ constexpr auto dst_scalar_step_in_vector = -+ generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); -+ -+ using SpaceFillingCurve = SpaceFillingCurve>; -+ -+ static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, -+ "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); -+ -+ constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); -+ -+ static_for<0, num_access, 1>{}([&](auto idx_1d) { -+ constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); -+ -+ // copy data from src_buf into dst_vector -+ static_for<0, DstScalarPerVector, 1>{}([&](auto i) { -+ // src_desc error, non constexpr, caused by merge transform -+ constexpr index_t src_offset = src_desc.CalculateOffset( -+ src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); -+ -+ constexpr index_t dst_offset = dst_desc.CalculateOffset( -+ dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); -+ -+ SrcData v_this_row; -+ // int type temp value due to intrinsic requirement -+ int temp = 0; -+ -+ // apply element-wise operation -+ element_op_(v_this_row, src_buf[Number{}]); -+ -+ // apply intra-row permute. -+ if constexpr(IntraRowSwizzlePerm) -+ { -+ temp = __builtin_amdgcn_permlane16( -+ temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); -+ v_this_row = type_convert_sp(temp); -+ } -+ -+ // apply type convert -+ dst_buf(Number{}) = type_convert_sp(v_this_row); -+ }); -+ }); -+ } -+ ElementwiseOperation element_op_{}; -+}; -+ - } // namespace ck -diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp -index 565195f53..9a9ebf559 100644 ---- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp -+++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp -@@ -11,12 +11,17 @@ namespace ck { - - enum struct WmmaInstr - { -+ // gfx11 - wmma_f32_16x16x16_f16 = 0, - wmma_f32_16x16x16_bf16, - wmma_f16_16x16x16_f16, - wmma_bf16_16x16x16_bf16, - wmma_i32_16x16x16_iu8, -- wmma_i32_16x16x16_iu4 -+ wmma_i32_16x16x16_iu4, -+ // gfx12 -+ wmma_f32_16x16x16_f16_gfx12, -+ wmma_f32_16x16x16_bf16_gfx12, -+ wmma_i32_16x16x16_iu8_gfx12, - }; - - /* -@@ -279,6 +284,122 @@ struct wmma_type -+struct wmma_type> -+{ -+ // Absolute fixing property -+ // * Data Pixel -+ static constexpr index_t m_per_wmma = 16; -+ static constexpr index_t n_per_wmma = 16; -+ static constexpr index_t k_per_wmma = 16; -+ // static constexpr index_t src_a_data_size = 2; -+ // static constexpr index_t src_b_data_size = 2; -+ // static constexpr index_t acc_data_size = 4; -+ // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction -+ static constexpr index_t acc_data_size = 4; -+ static constexpr index_t acc_pack_number = 1; -+ static constexpr index_t num_thread_per_subgroups = n_per_wmma; -+ -+ // Wave mode dependent propety -+ static constexpr index_t wave_size = Number{}; -+ // * Fixed in Navi3x, Will be wave mode dependent on Navi4x -+ // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * 
src_a_data_size / 4; -+ // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; -+ // * num_acc_vgprs_per_wave alone M direction -+ // * num_subgroups alone M direction -+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; -+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; -+ -+ template -+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const -+ { -+ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); -+ if constexpr(wave_size == 32) -+ { -+ intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); -+ } -+ } -+}; -+ -+template -+struct wmma_type> -+{ -+ // Absolute fixing property -+ static constexpr index_t m_per_wmma = 16; -+ static constexpr index_t n_per_wmma = 16; -+ static constexpr index_t k_per_wmma = 16; -+ // static constexpr index_t src_a_data_size = 2; -+ // static constexpr index_t src_b_data_size = 2; -+ static constexpr index_t acc_data_size = 4; -+ static constexpr index_t acc_pack_number = 1; -+ static constexpr index_t num_thread_per_subgroups = n_per_wmma; -+ -+ // Wave mode dependent propety -+ static constexpr index_t wave_size = Number{}; -+ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; -+ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; -+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; -+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; -+ -+ template -+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const -+ { -+ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); -+ if constexpr(wave_size == 32) -+ { -+ intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); -+ } -+ } -+}; -+ -+template -+struct wmma_type> -+{ -+ // Absolute fixing property -+ static constexpr index_t m_per_wmma = 16; -+ static constexpr index_t n_per_wmma = 16; -+ static constexpr index_t k_per_wmma = 16; -+ // static constexpr index_t src_a_data_size = 2; -+ // static constexpr index_t src_b_data_size = 2; -+ static constexpr index_t acc_data_size = 4; -+ static constexpr index_t acc_pack_number = 1; -+ static constexpr index_t num_thread_per_subgroups = n_per_wmma; -+ -+ // Wave mode dependent propety -+ static constexpr index_t wave_size = Number{}; -+ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; -+ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; -+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; -+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; -+ -+ template -+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const -+ { -+ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); -+ if constexpr(wave_size == 32) -+ { -+ intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( -+ a, b, reg_c); -+ } -+ } -+}; -+ - template - static constexpr auto GetWmma() - { -+#ifdef __gfx12__ -+ return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; -+#else - return WmmaInstr::wmma_f32_16x16x16_f16; -+#endif - } - - template <> - static constexpr auto GetWmma() - { -+#ifdef __gfx12__ -+ return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; -+#else - return WmmaInstr::wmma_f32_16x16x16_bf16; -+#endif - } - - template <> -@@ -320,8 +449,13 @@ struct WmmaSelector - template <> - static constexpr auto GetWmma() - { 
-+#ifdef __gfx12__ -+ return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; -+#else - return WmmaInstr::wmma_i32_16x16x16_iu8; -+#endif - } -+ - #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - template <> - static constexpr auto GetWmma() -@@ -502,6 +636,9 @@ struct WmmaGemm - - __device__ static auto GetSubGroupId() - { -+ static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == -+ wmma_instr.wave_size, -+ ""); - return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; - } - -@@ -516,12 +653,20 @@ struct WmmaGemm - - __host__ __device__ static auto CalculateAThreadOriginDataIndex() - { -+#ifdef __gfx12__ -+ return GetLaneIdUnderSubGroup(); -+#else - return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); -+#endif - } - - __host__ __device__ static auto CalculateBThreadOriginDataIndex() - { -+#ifdef __gfx12__ -+ return GetLaneIdUnderSubGroup(); -+#else - return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); -+#endif - } - - __device__ static CIndex GetBeginOfThreadBlk() -diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp -index 1bb0140f3..322a0f94b 100644 ---- a/include/ck/utility/amd_wmma.hpp -+++ b/include/ck/utility/amd_wmma.hpp -@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> - } - }; - -+// gfx12 -+/********************************WAVE32 MODE***********************************************/ -+ -+#if defined(__gfx1200__) || defined(__gfx1201__) -+#define __gfx12__ -+#endif -+ -+// src: fp16, dst: fp32 -+template -+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; -+ -+template <> -+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> -+{ -+ template -+ __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) -+ { -+ // * Inline assembly need to elimate the duplicated data load, compiler won't help you -+ // delete them. 
-+ // amd_assembly_wmma_f32_16x16x16_f16_w32( -+ // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); -+#if defined(__gfx12__) -+ reg_c.template AsType()(Number<0>{}) = -+ __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( -+ reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); -+#else -+ ignore = reg_a; -+ ignore = reg_b; -+ ignore = reg_c; -+#endif -+ } -+}; -+ -+// src: bf16, dst: fp32 -+template -+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; -+ -+template <> -+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> -+{ -+ template -+ __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) -+ { -+#if defined(__gfx12__) -+ reg_c.template AsType()(Number<0>{}) = -+ __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( -+ reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); -+#else -+ ignore = reg_a; -+ ignore = reg_b; -+ ignore = reg_c; -+#endif -+ } -+}; -+ -+// src: iu8, dst: i32 -+template -+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; -+ -+template -+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> -+{ -+ template -+ __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) -+ { -+#if defined(__gfx12__) -+ reg_c.template AsType()(Number<0>{}) = -+ __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( -+ neg_a, -+ bit_cast(reg_a), -+ neg_b, -+ bit_cast(reg_b), -+ reg_c.template AsType()[Number<0>{}], -+ clamp); -+#else -+ ignore = reg_a; -+ ignore = reg_b; -+ ignore = reg_c; -+#endif -+ } -+}; -+ - } // namespace ck - #endif -diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp -index 93a1edefb..4df14c621 100644 ---- a/include/ck/utility/data_type.hpp -+++ b/include/ck/utility/data_type.hpp -@@ -203,7 +203,7 @@ struct vector_type - } - }; - --int static err = 0; -+__device__ int static err = 0; - template - struct vector_type - { -diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp -index 4fe5e3950..d6b6eac26 100644 ---- a/include/ck/utility/synchronization.hpp -+++ b/include/ck/utility/synchronization.hpp -@@ -10,12 +10,20 @@ namespace ck { - __device__ void block_sync_lds() - { - #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM -+#ifdef __gfx12__ -+ asm volatile("\ -+ s_wait_dscnt 0x0 \n \ -+ s_barrier_signal -1 \n \ -+ s_barrier_wait -1 \ -+ " ::); -+#else - // asm volatile("\ - // s_waitcnt lgkmcnt(0) \n \ - // s_barrier \ - // " ::); - __builtin_amdgcn_s_waitcnt(0xc07f); - __builtin_amdgcn_s_barrier(); -+#endif - #else - __syncthreads(); - #endif -@@ -23,11 +31,20 @@ __device__ void block_sync_lds() - - __device__ void block_sync_lds_direct_load() - { -+#ifdef __gfx12__ -+ asm volatile("\ -+ s_wait_vmcnt 0x0 \n \ -+ s_wait_dscnt 0x0 \n \ -+ s_barrier_signal -1 \n \ -+ s_barrier_wait -1 \ -+ " ::); -+#else - asm volatile("\ - s_waitcnt vmcnt(0) \n \ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); -+#endif - } - - __device__ void s_nop() -diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp -index 601aad19b..9dc2b072a 100644 ---- a/include/ck_tile/core/config.hpp -+++ b/include/ck_tile/core/config.hpp -@@ -17,6 +17,9 @@ - #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) - #define __gfx11__ - #endif -+#if defined(__gfx1200__) || defined(__gfx1201__) -+#define __gfx12__ -+#endif - - #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS - #include "hip/hip_runtime.h" -@@ -155,7 +158,7 @@ - #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 - #elif 
defined(__gfx103__) // for GPU code - #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 --#elif defined(__gfx11__) // for GPU code -+#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code - #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 - #endif - -diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -index 8c5f36d2e..89c9d6dc6 100644 ---- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt -+++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -@@ -52,7 +52,7 @@ function(add_instance_library INSTANCE_NAME) - endforeach() - # Do not build WMMA instances if gfx11 targets are not on the target list - foreach(source IN LISTS ARGN) -- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") -+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") - message("removing wmma instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() -@@ -149,7 +149,7 @@ FOREACH(subdir_path ${dir_list}) - message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") - set(add_inst 0) - endif() -- if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) -+ if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12")) - message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") - set(add_inst 0) - endif() -@@ -157,11 +157,11 @@ FOREACH(subdir_path ${dir_list}) - message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") - set(add_inst 0) - endif() -- if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) -+ if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9")) - message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") - set(add_inst 0) - endif() -- if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) -+ if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) - message("Found xdl, dl, and wmma instances, but none of those meet the target list. 
Skipping.") - set(add_inst 0) - endif() -diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt -index 1cfcbfff6..a9557a9b9 100644 ---- a/profiler/src/CMakeLists.txt -+++ b/profiler/src/CMakeLists.txt -@@ -58,7 +58,7 @@ if(GPU_TARGETS MATCHES "gfx9") - - endif() - --if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") -+if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - endif() -@@ -133,7 +133,7 @@ if(GPU_TARGETS MATCHES "gfx9") - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) - endif() - --if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") -+if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - endif() -diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt -index 25c63ac7f..2a7c52b58 100644 ---- a/test/CMakeLists.txt -+++ b/test/CMakeLists.txt -@@ -53,7 +53,7 @@ function(add_test_executable TEST_NAME) - endif() - endforeach() - foreach(source IN LISTS ARGN) -- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") -+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") - message("removing wmma test ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() -@@ -118,7 +118,7 @@ function(add_gtest_executable TEST_NAME) - endif() - endforeach() - foreach(source IN LISTS ARGN) -- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") -+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") - message("removing wmma test ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() -diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp -index 1c8082645..21f49ec0f 100644 ---- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp -+++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp -@@ -55,7 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test - } - } - -- if(ck::is_gfx11_supported()) -+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - // on gfx11 only support for 3d is implemented - if constexpr(NDimSpatial{} != 3) -diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp -index 49782bce6..d9ec94771 100644 ---- a/test/wmma_op/wmma_op_util.hpp -+++ b/test/wmma_op/wmma_op_util.hpp -@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) - p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; - } - -+#ifdef __gfx12__ -+ asm volatile("\ -+ s_wait_dscnt 0x0 \n \ -+ s_barrier_signal -1 \n \ -+ s_barrier_wait -1 \ -+ " ::); -+#else - asm volatile("\ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); -+#endif - - for(int ele = 0; ele < 16; ++ele) - { -@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) - a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; - } - -+#ifdef __gfx12__ -+ asm volatile("\ -+ s_wait_dscnt 0x0 \n \ -+ s_barrier_signal -1 \n \ -+ s_barrier_wait -1 \ -+ " ::); -+#else - asm volatile("\ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); -+#endif - - // sync threads, similar to mma_sync - // __syncthreads(); 
diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch deleted file mode 100644 index d63da63445fde..0000000000000 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ /dev/null @@ -1,238 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index c23746e7f..bc326c8b5 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -23,10 +23,10 @@ endif() - - set(version 1.1.0) - # Check support for CUDA/HIP in Cmake --project(composable_kernel VERSION ${version} LANGUAGES CXX) -+project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) - include(CTest) - --find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) -+find_package(Python3 COMPONENTS Interpreter REQUIRED) - - list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") - -@@ -227,27 +227,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) - set(CMAKE_CXX_EXTENSIONS OFF) - message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") - --## OpenMP --if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") -- # workaround issue hipcc in rocm3.5 cannot find openmp -- set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") -- set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument") -- set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5") -- set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) -- set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) -- set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES}) --else() -- find_package(OpenMP REQUIRED) --endif() -- --message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") --message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") --message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") --message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") -- --link_libraries(${OpenMP_gomp_LIBRARY}) --link_libraries(${OpenMP_pthread_LIBRARY}) -- - ## HIP - find_package(HIP REQUIRED) - # Override HIP version in config.h, if necessary. 
-@@ -269,12 +248,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) - message(STATUS "CK_HIP_VERSION_PATCH overridden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") - endif() - message(STATUS "Build with HIP ${HIP_VERSION}") --link_libraries(hip::device) --if(CK_hip_VERSION VERSION_GREATER_EQUAL 6.0.23494) -- add_compile_definitions(__HIP_PLATFORM_AMD__=1) --else() -- add_compile_definitions(__HIP_PLATFORM_HCC__=1) --endif() - - ## tidy - include(EnableCompilerWarnings) -@@ -541,11 +514,3 @@ rocm_install(FILES - - set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") - set(CPACK_RPM_PACKAGE_LICENSE "MIT") -- --rocm_create_package( -- NAME composablekernel -- DESCRIPTION "High Performance Composable Kernel for AMD GPUs" -- MAINTAINER "MIOpen Kernels Dev Team " -- LDCONFIG -- HEADER_ONLY --) -diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py -index 51fecd07b..5ed371995 100644 ---- a/example/ck_tile/01_fmha/generate.py -+++ b/example/ck_tile/01_fmha/generate.py -@@ -566,7 +566,7 @@ def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], recei - def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: - assert output_file is not None - file_path = Path(output_file) -- with file_path.open('a') as f: -+ with file_path.open('w') as f: - _, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") -diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp -index c0894f1d7..559481fee 100644 ---- a/include/ck/host_utility/hip_check_error.hpp -+++ b/include/ck/host_utility/hip_check_error.hpp -@@ -6,19 +6,7 @@ - #include - #include - --// To be removed, which really does not tell the location of failed HIP functional call --inline void hip_check_error(hipError_t x) --{ -- if(x != hipSuccess) -- { -- std::ostringstream ss; -- ss << "HIP runtime error: " << hipGetErrorString(x) << ". 
" -- << "hip_check_error.hpp" -- << ": " << __LINE__ << "in function: " << __func__; -- throw std::runtime_error(ss.str()); -- } --} -- -+#ifndef HIP_CHECK_ERROR - #define HIP_CHECK_ERROR(retval_or_funcall) \ - do \ - { \ -@@ -32,3 +20,9 @@ inline void hip_check_error(hipError_t x) - throw std::runtime_error(ostr.str()); \ - } \ - } while(0) -+#endif -+ -+#ifndef hip_check_error -+#define hip_check_error HIP_CHECK_ERROR -+#endif -+ -diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp -index a164c3f94..293ead89a 100644 ---- a/include/ck_tile/core/utility/transpose_vectors.hpp -+++ b/include/ck_tile/core/utility/transpose_vectors.hpp -@@ -11,6 +11,9 @@ - - namespace ck_tile { - -+template -+constexpr bool always_false = false; -+ - // S: scalar type (or it can be non-scalar type) - // NX: # of vector before transpose - // NY: # of vector after transpose -@@ -117,9 +120,11 @@ struct transpose_vectors - } - else - { -- static_assert(false, "not implemented"); -+ static_assert(always_false, number>, "not implemented"); - } - } - }; - -+ - } // namespace ck_tile -+ -diff --git a/include/ck_tile/host/hip_check_error.hpp b/include/ck_tile/host/hip_check_error.hpp -index 3acdb4d87..cc26e184f 100644 ---- a/include/ck_tile/host/hip_check_error.hpp -+++ b/include/ck_tile/host/hip_check_error.hpp -@@ -8,20 +8,7 @@ - #include - #include - --namespace ck_tile { --// To be removed, which really does not tell the location of failed HIP functional call --CK_TILE_HOST void hip_check_error(hipError_t x) --{ -- if(x != hipSuccess) -- { -- std::ostringstream ss; -- ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ -- << "in function: " << __func__; -- throw std::runtime_error(ss.str()); -- } --} --} // namespace ck_tile -- -+#ifndef HIP_CHECK_ERROR - #define HIP_CHECK_ERROR(retval_or_funcall) \ - do \ - { \ -@@ -34,3 +21,9 @@ CK_TILE_HOST void hip_check_error(hipError_t x) - throw std::runtime_error(ostr.str()); \ - } \ - } while(0) -+#endif -+ -+#ifndef hip_check_error -+#define hip_check_error HIP_CHECK_ERROR -+#endif -+ -diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -index c035e7e56..8c5f36d2e 100644 ---- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt -+++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -@@ -59,8 +59,14 @@ function(add_instance_library INSTANCE_NAME) - endforeach() - #only continue if there are some source files left on the list - if(ARGN) -+ set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) - add_library(${INSTANCE_NAME} OBJECT ${ARGN}) -+ # Always disable debug symbol and C debug assert due to -+ # - Linker error: ... relocation truncated to fit ..., caused by object files to be linked are too huge. 
-+ # - https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/622 -+ target_compile_options(${INSTANCE_NAME} PRIVATE -g0 -DNDEBUG) - target_compile_features(${INSTANCE_NAME} PUBLIC) -+ target_compile_definitions(${INSTANCE_NAME} PRIVATE "__HIP_PLATFORM_AMD__=1" "__HIP_PLATFORM_HCC__=1") - set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) - clang_tidy_check(${INSTANCE_NAME}) - set(result 0) ---- ./include/ck/utility/amd_buffer_addressing.hpp 2024-09-05 10:12:33.343091000 +0800 -+++ ./include/ck/utility/amd_buffer_addressing_new.hpp 2024-09-05 10:12:20.276686000 +0800 -@@ -991,7 +991,8 @@ - asm volatile("s_mov_b32 m0, %0; \n\t" - "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), - "v"(global_offset_bytes), -- "s"(src_resource)); -+ "s"(src_resource) -+ : "memory"); - #else - // LDS pointer must be attributed with the LDS address space. - __attribute__((address_space(3))) uint32_t* lds_ptr = ---- ./include/ck_tile/core/arch/amd_buffer_addressing.hpp 2024-09-05 10:18:28.884031000 +0800 -+++ ./include/ck_tile/core/arch/amd_buffer_addressing_new.hpp 2024-09-05 10:17:29.434931000 +0800 -@@ -26,7 +26,12 @@ - CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) - { - buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; -- return __builtin_bit_cast(int32x4_t, res); -+ int32x4_t r = __builtin_bit_cast(int32x4_t, res); -+ r.x = __builtin_amdgcn_readfirstlane(r.x); -+ r.y = __builtin_amdgcn_readfirstlane(r.y); -+ r.z = __builtin_amdgcn_readfirstlane(r.z); -+ r.w = __builtin_amdgcn_readfirstlane(r.w); -+ return r; - } - - // TODO: glc/slc/... -@@ -2016,7 +2021,8 @@ - asm volatile("s_mov_b32 m0, %0; \n\t" - "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), - "v"(global_offset_bytes), -- "s"(src_resource)); -+ "s"(src_resource) -+ : "memory"); - #else - // LDS pointer must be attributed with the LDS address space. 
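The `amd_buffer_addressing` hunks above make two changes: they add a `"memory"` clobber to the `buffer_load_dword ... lds` inline assembly, and they broadcast each dword of the buffer descriptor with `__builtin_amdgcn_readfirstlane` so the descriptor is provably wave-uniform and stays in scalar registers. A self-contained sketch of the second change, with the descriptor struct and config dword simplified for illustration:

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

using int32x4_t = int __attribute__((ext_vector_type(4)));

// Simplified 16-byte buffer descriptor: base pointer, range, config dword.
struct buffer_resource_sketch
{
    const void* ptr;
    uint32_t range;
    uint32_t config; // e.g. the per-target BUFFER_RESOURCE_3RD_DWORD value
};

__device__ inline int32x4_t
make_wave_buffer_resource_sketch(const void* ptr, uint32_t size, uint32_t config)
{
    buffer_resource_sketch res{ptr, size, config};
    int32x4_t r = __builtin_bit_cast(int32x4_t, res);
    // readfirstlane makes every lane use lane 0's value, so the compiler can keep
    // the descriptor in SGPRs as the buffer-addressing asm expects.
    r.x = __builtin_amdgcn_readfirstlane(r.x);
    r.y = __builtin_amdgcn_readfirstlane(r.y);
    r.z = __builtin_amdgcn_readfirstlane(r.z);
    r.w = __builtin_amdgcn_readfirstlane(r.w);
    return r;
}
```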
- __attribute__((address_space(3))) uint32_t* lds_ptr = diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 73613541f8362..3779a72d4de69 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -164,10 +164,6 @@ public void TestSessionOptions() opt.AppendExecutionProvider_OpenVINO(); #endif -#if USE_ROCM - opt.AppendExecutionProvider_ROCm(0); -#endif - #if USE_TENSORRT opt.AppendExecutionProvider_Tensorrt(0); #endif @@ -1764,33 +1760,6 @@ void TestCUDAAllocatorInternal(InferenceSession session) } #endif -#if USE_ROCM - void TestROCMAllocatorInternal(InferenceSession session) - { - int device_id = 0; - using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorHIP, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default)) - { - Assert.Equal("Hip", info_rocm.Name); - Assert.Equal(device_id, info_rocm.Id); - Assert.Equal(OrtAllocatorType.ArenaAllocator, info_rocm.GetAllocatorType()); - Assert.Equal(OrtMemType.Default, info_rocm.GetMemoryType()); - - using (var allocator = new OrtAllocator(session, info_rocm)) - { - var alloc_info = allocator.Info; - Assert.True(info_rocm.Equals(alloc_info)); - - uint size = 1024; - OrtMemoryAllocation chunk = allocator.Allocate(size); - Assert.Equal(chunk.Size, size); - Assert.True(chunk.Info.Equals(alloc_info)); - chunk.Dispose(); - alloc_info.Dispose(); - } - } - } -#endif - [Fact(DisplayName = "TestAllocator")] private void TestAllocator() { @@ -1801,21 +1770,12 @@ private void TestAllocator() #if USE_CUDA options.AppendExecutionProvider_CUDA(0); #endif - -#if USE_ROCM - options.AppendExecutionProvider_ROCm(0); -#endif - using (var session = new InferenceSession(model, options)) { TestCPUAllocatorInternal(session); #if USE_CUDA TestCUDAAllocatorInternal(session); #endif -#if USE_ROCM - TestROCMAllocatorInternal(session); -#endif - } } } @@ -1942,15 +1902,6 @@ internal static Tuple, float[]> Op { option.AppendExecutionProvider_CPU(1); } -#elif USE_ROCM - using (var option = (deviceId.HasValue) ? 
- SessionOptions.MakeSessionOptionWithRocmProvider(deviceId.Value) : - new SessionOptions()) - { - if(!deviceId.HasValue) - { - option.AppendExecutionProvider_CPU(1); - } #else using (var option = new SessionOptions()) { diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs index ae4fb0cf164cd..94f8e927c1331 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs @@ -60,9 +60,6 @@ public void GetAvailableProviders() #if USE_CUDA Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider")); -#endif -#if USE_ROCM - Assert.True(Array.Exists(providers, provider => provider == "ROCMExecutionProvider")); #endif } } @@ -493,4 +490,3 @@ void TestCopyTensors() } } } - diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 89dbce05326b5..f0d1313783643 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -1531,7 +1531,6 @@ private void TestInferenceWithLoraAdapterFromArray() // TestGpu() will test // - the CUDA EP on CUDA enabled builds // - the DML EP on DML enabled builds - // - the ROCm EP on ROCm enabled builds [GpuFact(DisplayName = "TestGpu")] private void TestGpu() { @@ -1575,9 +1574,6 @@ private void VerifyNativeMethodsExist() #if USE_CUDA ,"OrtSessionOptionsAppendExecutionProvider_CUDA" #endif -#if USE_ROCM - ,"OrtSessionOptionsAppendExecutionProvider_ROCM" -#endif #if USE_DML ,"OrtSessionOptionsAppendExecutionProvider_DML" #endif diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index 876a07e4ffaf6..3c52b7574ccf6 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -5,22 +5,17 @@ # Dockerfile to run ONNXRuntime with MIGraphX integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 +FROM rocm/pytorch:rocm7.1_ubuntu24.04_py3.12_pytorch_release_2.9.1 -ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime +ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main -ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -RUN apt-get update &&\ - apt-get install -y migraphx - WORKDIR /code # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ cd onnxruntime && pip install --upgrade pip &&\ - /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` --config Release --parallel \ - --skip_tests --build_wheel --use_rocm --rocm_version=${ROCM_VERSION} --rocm_home /opt/rocm --use_migraphx &&\ + /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` \ + --config Release --parallel --skip_tests --build_wheel --use_migraphx &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm deleted file mode 100644 index aca8c3feaff71..0000000000000 --- a/dockerfiles/Dockerfile.rocm +++ /dev/null @@ -1,24 +0,0 @@ -# 
-------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------- -# Dockerfile to run ONNXRuntime with ROCm integration -#-------------------------------------------------------------------------- - -FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 - -ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime -ARG ONNXRUNTIME_BRANCH=main - -WORKDIR /code - -ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -# Prepare onnxruntime repository & build onnxruntime -RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ - /bin/sh ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel --cmake_extra_defines\ - ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\ - pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ - cd .. diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 4c69098103edd..88c542b63ccd2 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -1,9 +1,8 @@ # Dockerfiles **Execution Providers** - CPU: [Dockerfile](Dockerfile.source), [Instructions](#cpu) -- CUDA/cuDNN: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) +- CUDA: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) - MIGraphX: [Dockerfile](Dockerfile.migraphx), [Instructions](#migraphx) -- ROCm: [Dockerfile](Dockerfile.rocm), [Instructions](#rocm) - OpenVINO: [Dockerfile](Dockerfile.openvino), [Instructions](#openvino) - TensorRT: [Dockerfile](Dockerfile.tensorrt), [Instructions](#tensorrt) - VitisAI: [Dockerfile](Dockerfile.vitisai) @@ -304,17 +303,3 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-migraphx ``` - - ## ROCm -**Ubuntu 22.04, ROCm6.2.3** - -1. Build the docker image from the Dockerfile in this repository. - ``` - docker build -t onnxruntime-rocm -f Dockerfile.rocm . - ``` - -2. Run the Docker image - - ``` - docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-rocm - ``` diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh deleted file mode 100644 index fd445be87479b..0000000000000 --- a/dockerfiles/scripts/install_rocm_deps.sh +++ /dev/null @@ -1,84 +0,0 @@ -#!/bin/bash -prefix=/opt/rocm -DEBIAN_FRONTEND=noninteractive -apt-get update && apt-get install -y --no-install-recommends \ - wget \ - zip \ - ca-certificates \ - build-essential \ - curl \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev - -# rocm-cmake -rocm_cmake_version=4.5.2 -wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/refs/tags/rocm-${rocm_cmake_version}.tar.gz -tar -xzvf rocm-${rocm_cmake_version}.tar.gz -rm rocm-${rocm_cmake_version}.tar.gz -cd rocm-cmake-rocm-${rocm_cmake_version} -mkdir build -cd build -cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. 
-rm -rf rocm-cmake-rocm-${rocm_cmake_version} - -# rccl -rccl_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/refs/tags/rocm-${rccl_version}.tar.gz -tar -xzvf rocm-${rccl_version}.tar.gz -rm rocm-${rccl_version}.tar.gz -cd rccl-rocm-${rccl_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rccl-rocm-${rccl_version} - -#rocrand -rocrand_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/refs/tags/rocm-${rocrand_version}.tar.gz -tar -xzvf rocm-${rocrand_version}.tar.gz -rm rocm-${rocrand_version}.tar.gz -cd rocRAND-rocm-${rocrand_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocRAND-rocm-${rocrand_version} - -#hipcub -hipcub_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/refs/tags/rocm-${hipcub_version}.tar.gz -tar -xzvf rocm-${hipcub_version}.tar.gz -rm rocm-${hipcub_version}.tar.gz -cd hipCUB-rocm-${hipcub_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make package -make install -cd ../.. -rm -rf hipCUB-rocm-${hipcub_version} - -#rocprim -rocprim_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/refs/tags/rocm-${rocprim_version}.tar.gz -tar -xzvf rocm-${rocprim_version}.tar.gz -rm rocm-${rocprim_version}.tar.gz -cd rocPRIM-rocm-${rocprim_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocPRIM-rocm-${rocprim_version} - diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index e59a803d97629..5f391432ce503 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -190,13 +190,6 @@ KernelCreateInfo BuildKernelCreateInfo(); } // namespace js } // namespace contrib -namespace contrib { -namespace rocm { -template -KernelCreateInfo BuildKernelCreateInfo(); -} // namespace rocm -} // namespace contrib - namespace contrib { namespace snpe { template diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index 935be9c3f00c7..c85b01210fc3b 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -56,7 +56,7 @@ struct OrtDevice { enum VendorIds : VendorId { // No vendor ID. Valid for DeviceType::CPU + MemType::DEFAULT or for generic allocators like WebGPU. 
NONE = 0x0000, - AMD = 0x1002, // ROCm, MIGraphX EPs + AMD = 0x1002, // MIGraphX EP NVIDIA = 0x10DE, // CUDA/TensorRT ARM = 0x13B5, // ARM GPU EP MICROSOFT = 0x1414, // DML EP diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index d3f1182909b5c..fa34ef75f2eb5 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -44,7 +44,6 @@ constexpr const char* kDmlExecutionProvider = "DmlExecutionProvider"; constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; constexpr const char* kAclExecutionProvider = "ACLExecutionProvider"; constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider"; -constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; constexpr const char* kCoreMLExecutionProvider = "CoreMLExecutionProvider"; constexpr const char* kJsExecutionProvider = "JsExecutionProvider"; constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider"; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 11ca73790ea79..ef5cd49334133 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1454,12 +1454,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return Resolve(default_options); } + /// + /// This function converts all the graph TensorProto initializers into OrtValues + /// and creates a in-memory external data reference for each OrtValue. + /// + /// + Status ConvertInitializersIntoOrtValues(); + /** - * @brief Converts a subset of graph TensorProto initializers into OrtValues and updates the graph proto. - * - * This function converts specified TensorProto initializers in the graph into OrtValues and - * creates in-memory external data references for each OrtValue. It then updates the provided - * GraphProto with the modified initializers. + * @brief This function examines the specified initializers in the graph and converts them inline + * if any has external data in memory. * * @param iterators Span of iterators pointing to the initializers and the order that should be processed * @param output_graph_proto The GraphProto to be updated with the modified initializers @@ -1633,17 +1637,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /// Status indicating success or failure Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const; - /// - /// This function replaces all of the initializers within output_graph_proto - /// from this Graph instance. All in memory initializers are regenerated and inlined. - /// This is necessary even if the graph_proto_ is already up to date because initializers() may - /// contain obsolete initializers that are no longer in use due to optimizations and contain obsolete - /// references to OrtValues that may no longer be around (since we like appending rather than replacing). - /// - /// Destination GraphProto to receive the updated initializers. - /// Status indicating success or failure. 
- Status RegenerateInitializersAndReplaceInMemory(ONNX_NAMESPACE::GraphProto& output_graph_proto) const; - /// /// This function traverses the graph bottom up and externalizes /// constant initializers along with their pre-packed blobs from different @@ -1753,6 +1746,15 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::vector& output_types, const Graph::ResolveOptions& options); + // If ONNX operator's PartialDataPropagationFunction() infers concrete shape values in the output + // save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer) nodes + // can use them later for their TypeAndShapeInferenceFunction() and PartialDataPropagationFunction(). + common::Status SaveShapeValuesFromDataPropagation(const Node& node, NodeArg& output_def, + const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const; + + // Remove intermediate inferred shape values stored in all NodeArgs to reduce memory usage. + common::Status CleanUpShapeValuesFromDataPropagation(); + // Apply type-inference and type-checking to all inputs and initializers: common::Status TypeCheckInputsAndInitializers(); diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index 0ddf1a2b9d3de..4a18d7617ac13 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -9,6 +9,8 @@ #include "core/common/status.h" #include "core/common/logging/logging.h" +#include + namespace onnxruntime { // Node argument definition, for both input and output, @@ -107,6 +109,18 @@ class NodeArg { /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */ const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } + /** Gets the inferred shape values as a TensorShapeProto. */ + const std::optional& GetInferredShapeValues() const noexcept { return inferred_shape_values_; } + + /** Gets mutable inferred shape values as a TensorShapeProto. */ + std::optional& GetMutableInferredShapeValues() noexcept { return inferred_shape_values_; } + + /** Gets the inferred shape scalar value */ + const std::optional GetInferredShapeScalarValue() const noexcept { return inferred_scalar_value_; } + + /** Sets the inferred shape scalar value */ + void SetInferredShapeScalarValue(int64_t value) noexcept { inferred_scalar_value_ = value; } + /** Gets a flag indicating whether this NodeArg exists or not. Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ bool Exists() const noexcept; @@ -128,6 +142,24 @@ class NodeArg { // Node arg name, type and shape. NodeArgInfo node_arg_info_; + // This variable stores the actual tensor data of the shape as a TensorShapeProto after executing + // the ONNX operator's PartialDataPropagationFunction(). It's used for shape inference purpose. + // + // Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient + // for complete shape inference. For example, the Shape operator's TypeAndShapeInferenceFunction() + // only provides the output's rank which is 1 but not its actual shape values. + // + // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also + // be executed to obtain the concrete shape values output, allowing accurate propagation + // of shape information throughout the graph. If the concrete shape values output is not + // computed, nothing is stored here that's why this is optional. 
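The comment block above, together with the `inferred_shape_values_` and `inferred_scalar_value_` members declared just after it, explains why running only TypeAndShapeInferenceFunction() is not enough for ops like Shape. A simplified standalone model of the idea (a toy sketch, not the actual NodeArg/Graph classes): the output of a Shape-like node carries an optional cache of concrete dimension values that a downstream consumer can read during its own inference.

```cpp
#include <cstdio>
#include <optional>
#include <vector>

// Toy stand-in for a NodeArg: shape inference alone only knows the rank, while
// data propagation can optionally attach the concrete values the tensor will hold.
struct NodeArgSketch
{
    int rank = -1;                                        // from shape inference
    std::optional<std::vector<int64_t>> inferred_values;  // from data propagation
};

// "Shape" op: its output is a 1-D tensor; data propagation fills in the values.
NodeArgSketch InferShapeOp(const std::vector<int64_t>& input_dims)
{
    NodeArgSketch out;
    out.rank = 1;                      // all that type/shape inference can say
    out.inferred_values = input_dims;  // the concrete values, propagated
    return out;
}

// Reshape-like consumer: only produces a concrete output shape when the upstream
// data propagation actually stored the values.
void ConsumeAsNewShape(const NodeArgSketch& shape_arg)
{
    if (shape_arg.inferred_values) {
        std::printf("output dims known, rank %zu\n", shape_arg.inferred_values->size());
    } else {
        std::printf("only rank %d known; output dims stay symbolic\n", shape_arg.rank);
    }
}

int main()
{
    ConsumeAsNewShape(InferShapeOp({2, 3, 4}));          // dims known
    ConsumeAsNewShape(NodeArgSketch{1, std::nullopt});   // only rank known
    return 0;
}
```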
+ std::optional inferred_shape_values_; + + // This variable stores the actual scalar value. + // It is also used for shape inference and data propagation to ensure consistent shape and + // value information throughout the graph. + std::optional inferred_scalar_value_; + // Flag indicates whether <*this> node arg exists or not. bool exists_; }; diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index 026fc3b2dc0a0..c9cd2a00ec167 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -7,8 +7,10 @@ * - `kDeviceId`: Specifies the GPU device ID to use. * - `kHasUserComputeStream`: Indicates whether a user-provided compute stream is used. * - `kUserComputeStream`: Specifies the user-provided compute stream. + * - `kUserAuxStreamArray`: Specifies the user-provided aux stream. * - `kMaxWorkspaceSize`: Sets the maximum workspace size for GPU memory allocation. * - 'kMaxSharedMemSize': Sets the maximum amount of shared memory that TensorRT kernels are allowed to use + * - `kLengthAuxStreamArray`: Specifies the length/size of the auxiliary streams array (kUserAuxStreamArray). Also sets the maximum number of auxiliary streams for TensorRT execution. * - `kDumpSubgraphs`: Enables or disables dumping of subgraphs for debugging. * - `kDetailedBuildLog`: Enables or disables detailed build logs for debugging. * - `kProfilesMinShapes`: Specifies the minimum shapes for profiling. @@ -24,8 +26,10 @@ namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; constexpr const char* kUserComputeStream = "user_compute_stream"; +constexpr const char* kUserAuxStreamArray = "user_aux_stream_array"; constexpr const char* kMaxWorkspaceSize = "nv_max_workspace_size"; constexpr const char* kMaxSharedMemSize = "nv_max_shared_mem_size"; +constexpr const char* kLengthAuxStreamArray = "nv_length_aux_stream_array"; constexpr const char* kDumpSubgraphs = "nv_dump_subgraphs"; constexpr const char* kDetailedBuildLog = "nv_detailed_build_log"; constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes"; diff --git a/include/onnxruntime/core/providers/resource.h b/include/onnxruntime/core/providers/resource.h index bd123e1cd41c2..8f9451ad11a4e 100644 --- a/include/onnxruntime/core/providers/resource.h +++ b/include/onnxruntime/core/providers/resource.h @@ -7,8 +7,8 @@ enum ResourceOffset { cpu_resource_offset = 0, cuda_resource_offset = 10000, dml_resource_offset = 20000, - rocm_resource_offset = 30000, + migraphx_resource_offset = 30000, // offsets for other ort eps custom_ep_resource_offset = 10000000, // offsets for customized eps -}; \ No newline at end of file +}; diff --git a/include/onnxruntime/core/providers/rocm/rocm_context.h b/include/onnxruntime/core/providers/rocm/rocm_context.h deleted file mode 100644 index aad1736217129..0000000000000 --- a/include/onnxruntime/core/providers/rocm/rocm_context.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#define ORT_ROCM_CTX - -#include "rocm_resource.h" -#include "core/providers/custom_op_context.h" -#include -#include -#include - -namespace Ort { - -namespace Custom { - -struct RocmContext : public CustomOpContext { - hipStream_t hip_stream = {}; - miopenHandle_t miopen_handle = {}; - hipblasHandle_t blas_handle = {}; - - void Init(const OrtKernelContext& kernel_ctx) { - const auto& ort_api = Ort::GetApi(); - void* resource = {}; - OrtStatus* status = nullptr; - - status = ort_api.KernelContext_GetResource( - &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::hip_stream_t, &resource); - if (status) { - ORT_CXX_API_THROW("failed to fetch hip stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - hip_stream = reinterpret_cast(resource); - - resource = {}; - status = ort_api.KernelContext_GetResource( - &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::miopen_handle_t, &resource); - if (status) { - ORT_CXX_API_THROW("failed to fetch miopen handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - miopen_handle = reinterpret_cast(resource); - - resource = {}; - status = ort_api.KernelContext_GetResource( - &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::hipblas_handle_t, &resource); - if (status) { - ORT_CXX_API_THROW("failed to fetch hipblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - blas_handle = reinterpret_cast(resource); - } -}; - -} // namespace Custom -} // namespace Ort diff --git a/include/onnxruntime/core/providers/rocm/rocm_resource.h b/include/onnxruntime/core/providers/rocm/rocm_resource.h deleted file mode 100644 index db032b48714c3..0000000000000 --- a/include/onnxruntime/core/providers/rocm/rocm_resource.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/resource.h" - -#define ORT_ROCM_RESOURCE_VERSION 1 - -enum RocmResource : int { - hip_stream_t = rocm_resource_offset, - miopen_handle_t, - hipblas_handle_t, - deferred_cpu_allocator_t, - // below are rocm ep options - device_id_t, // 10004 - arena_extend_strategy_t -}; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 22708bbf06a3d..bc75aabc7e229 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -644,6 +644,9 @@ ORT_DEFINE_RELEASE(ValueInfo); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi); // This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, // but the struct has V2 in its name to indicate that it is the second version of the options. 
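Release helpers for the new `OrtKernelDef`, `OrtKernelDefBuilder` and `OrtKernelRegistry` handles are registered just above, and the corresponding C++ wrappers (`Ort::KernelDefBuilder`, `Ort::KernelDef`, `Ort::KernelRegistry`) are declared in the hunks that follow. A hypothetical usage sketch for a plugin EP, using only the methods declared in this change; the operator, opset range and EP name are placeholders, and registration into a `KernelRegistry` is left as a comment because it additionally needs the EP's kernel-create callback:

```cpp
#include <iostream>
#include "onnxruntime_cxx_api.h"

// Build a kernel definition with the new C++ wrappers. "MyPluginEp" and the
// Gemm/opset-13 choice are illustrative, not ONNX Runtime defaults.
void BuildGemmKernelDefSketch()
{
    Ort::KernelDefBuilder builder;
    builder.SetOperatorType("Gemm")
        .SetDomain("")                        // default ONNX domain
        .SetSinceVersion(13, 13)
        .SetExecutionProvider("MyPluginEp")
        .SetInputMemType(0, OrtMemTypeDefault);

    Ort::KernelDef def = builder.Build();
    std::cout << def.GetOperatorType() << " (opset "
              << def.GetSinceVersion().first << ") for "
              << def.GetExecutionProvider() << "\n";

    // A real plugin EP would now call Ort::KernelRegistry::AddKernel() with this
    // definition, its kernel-create callback and the callback's state pointer.
}
```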
@@ -1441,6 +1444,12 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); + + ///< Wraps OrtApi::AddFreeDimensionOverride + SessionOptionsImpl& AddFreeDimensionOverride(const char* dim_denotation, int64_t dim_value); + + ///< Wraps OrtApi::AddFreeDimensionOverrideByName + SessionOptionsImpl& AddFreeDimensionOverrideByName(const char* dim_name, int64_t dim_value); }; } // namespace detail @@ -3286,5 +3295,89 @@ struct Model : detail::ModelImpl { explicit Model(const std::vector& opsets); #endif }; + +namespace detail { +template +struct ConstKernelDefImpl : Base { + using B = Base; + using B::B; + + ///< Wraps OrtEpApi::KernelDef_GetOperatorType + const char* GetOperatorType() const; + + ///< Wraps OrtEpApi::KernelDef_GetDomain + const char* GetDomain() const; + + ///< Wraps OrtEpApi::KernelDef_GetSinceVersion + std::pair GetSinceVersion() const; + + ///< Wraps OrtEpApi::KernelDef_GetExecutionProvider + const char* GetExecutionProvider() const; + + ///< Wraps OrtEpApi::KernelDef_GetInputMemType + OrtMemType GetInputMemType(size_t input_index) const; + + ///< Wraps OrtEpApi::KernelDef_GetOutputMemType + OrtMemType GetOutputMemType(size_t output_index) const; +}; +} // namespace detail + +using ConstKernelDef = detail::ConstKernelDefImpl>; + +struct KernelDef : detail::ConstKernelDefImpl { + using Base = detail::ConstKernelDefImpl; + using Base::Base; + + explicit KernelDef(std::nullptr_t) {} + explicit KernelDef(OrtKernelDef* p) : detail::ConstKernelDefImpl{p} {} + + ConstKernelDef GetConst() const { return ConstKernelDef{this->p_}; } +}; + +/** \brief Builder for OrtKernelDef. + * + * Used by plugin EPs to build a kernel definition. + */ +struct KernelDefBuilder : detail::Base { + KernelDefBuilder(); ///< Wraps OrtEpApi::CreateKernelDefBuilder + explicit KernelDefBuilder(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used + explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder); + + KernelDefBuilder& SetOperatorType(const char* op_type); + KernelDefBuilder& SetDomain(const char* domain); + KernelDefBuilder& SetSinceVersion(int since_version_start, int since_version_end); + KernelDefBuilder& SetExecutionProvider(const char* ep_name); + KernelDefBuilder& SetInputMemType(size_t input_index, OrtMemType mem_type); + KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type); + KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtDataType* data_type); + KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector& data_types); + KernelDefBuilder& AddInputOutputAlias(int input_index, int output_index); + KernelDefBuilder& AddInputOutputAliases(const std::vector& input_indices, + const std::vector& output_indices); + KernelDefBuilder& AddInputOutputMutableAlias(int input_index, int output_index); + KernelDefBuilder& AddInputOutputMutableAliases(const std::vector& input_indices, + const std::vector& output_indices); + + KernelDef Build(); +}; + +/** \brief Registry for kernels supported by an EP. + * + * Used by plugin EPs to register definitions for supported kernels. 
+ */ +struct KernelRegistry : detail::Base { + ///< Wrapper around OrtEpApi::CreateKernelRegistry + KernelRegistry(); + + ///< Create an empty object, must be assigned a valid one to be used + explicit KernelRegistry(std::nullptr_t) {} + + ///< Take ownership of a pointer created with the C API. + explicit KernelRegistry(OrtKernelRegistry* ort_kernel_registry); + + ///< Wraps KernelRegistry_AddKernel + Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func, + void* kernel_create_func_state); +}; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 5144418db2b58..aff1061a67fea 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -12,6 +12,7 @@ #include #include #include +#include #include // Convert OrtStatus to Ort::Status and return @@ -1503,6 +1504,18 @@ inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsUsingFunct return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AddFreeDimensionOverride(const char* dim_denotation, int64_t dim_value) { + ThrowOnError(GetApi().AddFreeDimensionOverride(this->p_, dim_denotation, dim_value)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddFreeDimensionOverrideByName(const char* dim_name, int64_t dim_value) { + ThrowOnError(GetApi().AddFreeDimensionOverrideByName(this->p_, dim_name, dim_value)); + return *this; +} + /// Session template inline size_t ConstSessionImpl::GetInputCount() const { @@ -3560,4 +3573,144 @@ inline Model::Model(const std::vector& opsets) { } #endif +namespace detail { +template +inline const char* ConstKernelDefImpl::GetOperatorType() const { + return GetEpApi().KernelDef_GetOperatorType(this->p_); +} + +template +inline const char* ConstKernelDefImpl::GetDomain() const { + return GetEpApi().KernelDef_GetDomain(this->p_); +} + +template +inline std::pair ConstKernelDefImpl::GetSinceVersion() const { + int start = 0; + int end = 0; + + ThrowOnError(GetEpApi().KernelDef_GetSinceVersion(this->p_, &start, &end)); + return std::pair(start, end); +} + +template +inline const char* ConstKernelDefImpl::GetExecutionProvider() const { + return GetEpApi().KernelDef_GetExecutionProvider(this->p_); +} + +template +inline OrtMemType ConstKernelDefImpl::GetInputMemType(size_t input_index) const { + OrtMemType mem_type{}; + ThrowOnError(GetEpApi().KernelDef_GetInputMemType(this->p_, input_index, &mem_type)); + + return mem_type; +} + +template +inline OrtMemType ConstKernelDefImpl::GetOutputMemType(size_t output_index) const { + OrtMemType mem_type{}; + ThrowOnError(GetEpApi().KernelDef_GetOutputMemType(this->p_, output_index, &mem_type)); + + return mem_type; +} +} // namespace detail + +inline KernelDefBuilder::KernelDefBuilder() { + ThrowOnError(GetEpApi().CreateKernelDefBuilder(&p_)); +} + +inline KernelDefBuilder::KernelDefBuilder(OrtKernelDefBuilder* p) : detail::Base{p} { +} + +inline KernelDefBuilder& KernelDefBuilder::SetOperatorType(const char* op_type) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetOperatorType(p_, op_type)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetDomain(const char* domain) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetDomain(p_, domain)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetSinceVersion(int since_version_start, int
since_version_end) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetSinceVersion(p_, since_version_start, since_version_end)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetExecutionProvider(const char* ep_name) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetExecutionProvider(p_, ep_name)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetInputMemType(size_t input_index, OrtMemType mem_type) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetInputMemType(p_, input_index, mem_type)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::SetOutputMemType(size_t output_index, OrtMemType mem_type) { + ThrowOnError(GetEpApi().KernelDefBuilder_SetOutputMemType(p_, output_index, mem_type)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name, + const OrtDataType* data_type) { + ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, &data_type, 1)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name, + const std::vector& data_types) { + ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, data_types.data(), data_types.size())); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddInputOutputAlias(int input_index, int output_index) { + ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputAliases(p_, &input_index, &output_index, 1)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddInputOutputAliases(const std::vector& input_indices, + const std::vector& output_indices) { + if (input_indices.size() != output_indices.size()) { + ORT_CXX_API_THROW("Expecting input and output indices to have the same element count", ORT_INVALID_ARGUMENT); + } + + ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputAliases(p_, input_indices.data(), output_indices.data(), + input_indices.size())); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddInputOutputMutableAlias(int input_index, int output_index) { + ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputMutableAliases(p_, &input_index, &output_index, 1)); + return *this; +} + +inline KernelDefBuilder& KernelDefBuilder::AddInputOutputMutableAliases(const std::vector& input_indices, + const std::vector& output_indices) { + if (input_indices.size() != output_indices.size()) { + ORT_CXX_API_THROW("Expecting input and output indices to have the same element count", ORT_INVALID_ARGUMENT); + } + + ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputMutableAliases(p_, input_indices.data(), output_indices.data(), + input_indices.size())); + return *this; +} + +inline KernelDef KernelDefBuilder::Build() { + OrtKernelDef* kernel_def = nullptr; + ThrowOnError(GetEpApi().KernelDefBuilder_Build(p_, &kernel_def)); + return KernelDef(kernel_def); +} + +inline KernelRegistry::KernelRegistry() { + ThrowOnError(GetEpApi().CreateKernelRegistry(&p_)); +} + +inline KernelRegistry::KernelRegistry(OrtKernelRegistry* p) : detail::Base{p} { +} + +inline Status KernelRegistry::AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func, + void* kernel_create_func_state) { + return Status{GetEpApi().KernelRegistry_AddKernel(p_, kernel_def, kernel_create_func, kernel_create_func_state)}; +} } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index de38085914516..6fa5c8dea04e6 100644 --- 
a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -24,6 +24,12 @@ ORT_RUNTIME_CLASS(DataTransferImpl); ORT_RUNTIME_CLASS(SyncNotificationImpl); ORT_RUNTIME_CLASS(SyncStreamImpl); +// Opaque types for kernel-based EPs +ORT_RUNTIME_CLASS(KernelRegistry); +ORT_RUNTIME_CLASS(KernelDefBuilder); +ORT_RUNTIME_CLASS(KernelDef); +ORT_RUNTIME_CLASS(DataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType + /** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU. * * \since Version 1.23. @@ -274,6 +280,51 @@ struct OrtNodeComputeInfo { void(ORT_API_CALL* ReleaseState)(_In_ OrtNodeComputeInfo* this_ptr, _Frees_ptr_opt_ void* compute_state); }; +struct OrtKernelImpl; +typedef struct OrtKernelImpl OrtKernelImpl; + +/** + * \brief Contains functions that an OrtEp implements to specify the computation for an operator kernel. + * \since Version 1.24. + */ +struct OrtKernelImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Computation function called to execute the kernel on an EP. + * + * \param[in] this_ptr The OrtKernelImpl instance. + * \param[in] context The OrtKernelContext instance that provides access to the inputs and outputs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context); + + /** \brief Called by ORT to release the OrtKernelImpl instance and its resources. + * + * \param[in] this_ptr The OrtKernelImpl instance. + * + * \since Version 1.24. + */ + ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr); +}; + +/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel. + * + * \param[in] kernel_create_func_state Opaque state initially provided by the EP that registered the kernel. + * Refer to OrtEpApi::KernelRegistry_AddKernel(). May be null. + * \param[in] info The OrtKernelInfo instance that provides access to the kernel's input and output characteristics. + * \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ +typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ void* kernel_create_func_state, + _In_ const OrtKernelInfo* info, + _Outptr_result_maybenull_ OrtKernelImpl** kernel_out); + /** * \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider. * @@ -507,6 +558,294 @@ struct OrtEpApi { _Out_ OrtHardwareDevice** hardware_device); ORT_CLASS_RELEASE(HardwareDevice); + + /** \brief Creates an empty kernel registry. A kernel registry contains kernel creation information for + * every operator kernel supported by an EP. + * + * \remarks Refer to OrtEp::GetKernelRegistry, which returns an EP's kernel registry to ORT. + * + * \param[out] kernel_registry Output parameter set to the new OrtKernelRegistry instance. + * Must be released with OrtEpApi::ReleaseKernelRegistry. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry); + + ORT_CLASS_RELEASE(KernelRegistry); + + /** \brief Adds kernel creation information for a supported operator kernel to the given kernel registry. 
+ * + * \remarks Refer to OrtEp::GetKernelRegistry, which returns an EP's kernel registry to ORT. + * + * \param[in] kernel_registry The OrtKernelRegistry instance. + * \param[in] kernel_def The kernel definition, which includes operator type, version, EP name, type constraints, etc. + * \param[in] kernel_create_func Function that creates an instance of the operator kernel as a OrtKernelImpl instance. + * \param[in] kernel_create_func_state Custom state passed to the kernel creation function. Can be null. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry, + _In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func, + _In_ void* kernel_create_func_state); + + /** \brief Creates a kernel definition builder used to create instances of OrtKernelDef. + * + * \param[out] kernel_def_builder_out Output parameter set to the new OrtKernelDefBuilder instance. + * Must be released with OrtEpApi::ReleaseKernelDefBuilder(). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out); + + ORT_CLASS_RELEASE(KernelDefBuilder); + + /** \brief Sets the kernel's operator type. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] op_type A null-terminated string representing the operator type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* op_type); + + /** \brief Sets the kernel's domain. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] domain A null-terminated string representing the operator's domain. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, _In_ const char* domain); + + /** \brief Sets the kernel's opset version range that is supported. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] since_version_start The starting opset version that is supported. + * \param[in] since_version_end The ending opset version (inclusive) that is supported. + * Can be set equal to the starting version to indicate that only one + * version is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ int since_version_start, _In_ int since_version_end); + + /** \brief Sets the name of the kernel's intended execution provider. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] ep_name A null-terminated string representing the execution provider's name. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* ep_name); + + /** \brief Sets the memory type for a kernel input. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] input_index The index of the input. + * \param[in] mem_type The input's memory type. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t input_index, _In_ OrtMemType mem_type); + + /** \brief Sets the memory type for a kernel output. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] output_index The index of the output. + * \param[in] mem_type The output's memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t output_index, _In_ OrtMemType mem_type); + + /** \brief Adds type constraints for a kernel argument represented as a string (e.g., "T"). + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] arg_name A null-terminated string representing the argument to constrain (e.g., "T"). + * \param[in] types Array of OrtDataType instances representing allowed types for the argument. + * Must contain `num_types` elements. + * \param[in] num_types The number of OrtDataType elements in the `types` array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* arg_name, _In_reads_(num_types) const OrtDataType* const* types, + _In_ size_t num_types); + + /** \brief Adds aliases for the given input and output pairs. + * + * \note Used for operators like Identity and Reshape to allow ORT to reuse the input buffer for the output + * without modification. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] input_indices Array of input indices. Array must contain `num_io_indices` elements. + * \param[in] output_indices Array of output indices. Each output index is aliased with a corresponding + * input index in `input_indices`. Array must contain `num_io_indices` elements. + * \param[in] num_io_indices The number of input/output index pairs to alias. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_AddInputOutputAliases, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_reads_(num_io_indices) int const* input_indices, + _In_reads_(num_io_indices) int const* output_indices, + _In_ size_t num_io_indices); + + /** \brief Adds mutable aliases for the given input and output pairs. + * + * \note Allows ORT to reuse and *modify* an input buffer (in-place) for the output buffer. + * This is also known as "MayInplace" within the ORT codebase. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[in] input_indices Array of input indices. Array must contain `num_io_indices` elements. + * \param[in] output_indices Array of output indices. Each output index is aliased with a corresponding + * input index in `input_indices`. Array must contain `num_io_indices` elements. + * \param[in] num_io_indices The number of input/output index pairs to alias. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(KernelDefBuilder_AddInputOutputMutableAliases, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_reads_(num_io_indices) int const* input_indices, + _In_reads_(num_io_indices) int const* output_indices, + _In_ size_t num_io_indices); + + /** \brief Creates an OrtKernelDef instance from the given kernel definition builder. + * + * \param[in] kernel_def_builder The OrtKernelDefBuilder instance. + * \param[out] kernel_def_out The new OrtKernelDef instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder, + _Outptr_ OrtKernelDef** kernel_def_out); + + ORT_CLASS_RELEASE(KernelDef); + + /** \brief Returns the operator type from the kernel definition. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \return A null-terminated string representing the operator type. + * + * \since Version 1.24. + */ + ORT_API_T(const char*, KernelDef_GetOperatorType, _In_ const OrtKernelDef* kernel_def); + + /** \brief Returns the operator's domain from the kernel definition. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \return A null-terminated string representing the operator's domain. + * + * \since Version 1.24. + */ + ORT_API_T(const char*, KernelDef_GetDomain, _In_ const OrtKernelDef* kernel_def); + + /** \brief Gets the kernel's opset version range that is supported. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \param[out] start_version Output parameter set to the starting opset version that is supported. + * \param[out] end_version Output parameter set to the ending opset version (inclusive) that is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDef_GetSinceVersion, _In_ const OrtKernelDef* kernel_def, + _Out_ int* start_version, _Out_ int* end_version); + + /** \brief Returns the name of the kernel's intended execution provider. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \return A null-terminated string representing the name of the execution provider. + * + * \since Version 1.24. + */ + ORT_API_T(const char*, KernelDef_GetExecutionProvider, _In_ const OrtKernelDef* kernel_def); + + /** \brief Gets the memory type for a kernel input. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \param[in] input_index The index of the input. + * \param[out] mem_type Output parameter set to the input's memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDef_GetInputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t input_index, _Out_ OrtMemType* mem_type); + + /** \brief Gets the memory type for a kernel output. + * + * \param[in] kernel_def The OrtKernelDef instance. + * \param[in] output_index The index of the output. + * \param[out] mem_type Output parameter set to the output's memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(KernelDef_GetOutputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t output_index, _Out_ OrtMemType* mem_type); + + /** \brief Gets the OrtDataType that represents the data type for a tensor of the given element type. + * + * \param[in] elem_type The tensor's element type. + * \param[out] out Output parameter set to the OrtDataType. Owned by ORT and must not be released.
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetTensorDataType, _In_ ONNXTensorElementDataType elem_type, + _Outptr_ const OrtDataType** out); + + /** \brief Gets the kernel definition for a given node, if any exists for the calling execution provider. + * + * Used within OrtEp::GetCapability() to get the registered kernel definition for the given node. + * The kernel definition is set to NULL if there is no registered kernel definition for the node + * and execution provider. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance to query. + * \param[in] node The node for which to look up a kernel definition. + * \param[out] out_kernel_def Output parameter set to the OrtKernelDef or NULL. + * Owned by ORT and must not be released. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def); }; /** @@ -603,6 +942,9 @@ struct OrtEp { * graphs are only valid for the duration of the call to Compile. Any graph/node/input/output * names that are needed by the OrtNodeComputeInfo functions must be copied and stored by the OrtEp. * + * \note As of version 1.24, implementation of this function is optional if the EP does not compile nodes and + * uses a kernel registry instead. + * * \since Version 1.23. */ ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, @@ -616,6 +958,9 @@ struct OrtEp { * \param[inout] node_compute_infos The OrtNodeComputeInfo instances to release. * \param[in] num_node_compute_infos The number of OrtNodeComputeInfo instances. * + * \note As of version 1.24, implementation of this function is optional if the EP does not compile nodes and + * uses a kernel registry instead. + * * \since Version 1.23. */ ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr, @@ -768,6 +1113,22 @@ struct OrtEp { */ ORT_API_T(const char*, GetCompiledModelCompatibilityInfo, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph); + + /** \brief Gets the execution provider's kernel registry, if any. + * + * A kernel registry contains kernel creation information for operator kernels supported by an EP. + * + * \param[in] this_ptr The OrtEp instance. + * \param[out] kernel_registry Output parameter set to the EP's kernel registry, which must remain valid throughout + * the lifetime of the EP. Can be NULL if the EP doesn't use a kernel registry. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. If set to NULL, ORT assumes the EP compiles nodes. + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetKernelRegistry, _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. 
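To make the intended use of the kernel-registry API introduced above concrete, the following is a minimal sketch of how a plugin EP could populate a registry using the new Ort::KernelDefBuilder / Ort::KernelRegistry wrappers together with the OrtKernelImpl / OrtKernelCreateFunc contract. It is illustrative only: the kernel type MyReluKernel, the EP name "example_ep", and the opset range are hypothetical, and it assumes Ort::GetEpApi() is the public accessor the new wrappers use to reach OrtEpApi. An EP would typically return the underlying OrtKernelRegistry* from OrtEp::GetKernelRegistry and keep the registry alive for its own lifetime.

// Sketch only (assumptions noted above); not taken from this PR's tests or samples.
#include "onnxruntime_cxx_api.h"

// Hypothetical kernel for Relu. Deriving from OrtKernelImpl lets a pointer to this
// struct be handed to ORT as an OrtKernelImpl*.
struct MyReluKernel : OrtKernelImpl {
  MyReluKernel() {
    ort_version_supported = ORT_API_VERSION;
    Compute = ComputeImpl;
    Release = ReleaseImpl;
  }

  static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* /*this_ptr*/, OrtKernelContext* context) noexcept {
    // A real kernel would read inputs from `context` and write outputs here.
    (void)context;
    return nullptr;  // nullptr signals success
  }

  static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
    delete static_cast<MyReluKernel*>(this_ptr);
  }
};

// OrtKernelCreateFunc registered alongside the kernel definition.
static OrtStatus* ORT_API_CALL CreateMyReluKernel(void* /*kernel_create_func_state*/,
                                                  const OrtKernelInfo* /*info*/,
                                                  OrtKernelImpl** kernel_out) {
  *kernel_out = new MyReluKernel();
  return nullptr;
}

// Builds the registry an EP would expose through OrtEp::GetKernelRegistry.
Ort::KernelRegistry BuildExampleKernelRegistry() {
  const OrtDataType* float_type = nullptr;
  Ort::ThrowOnError(Ort::GetEpApi().GetTensorDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &float_type));

  Ort::KernelDef relu_def = Ort::KernelDefBuilder{}
                                .SetOperatorType("Relu")
                                .SetDomain("")  // default ONNX domain
                                .SetSinceVersion(14, 14)
                                .SetExecutionProvider("example_ep")
                                .AddTypeConstraint("T", float_type)
                                .Build();

  Ort::KernelRegistry registry;
  Ort::Status status = registry.AddKernel(relu_def, CreateMyReluKernel, /*kernel_create_func_state*/ nullptr);
  if (!status.IsOK()) {
    ORT_CXX_API_THROW(status.GetErrorMessage(), status.GetErrorCode());
  }
  return registry;
}

Kernels that can safely run in place (for example element-wise ops) would additionally call AddInputOutputMutableAlias/AddInputOutputMutableAliases on the builder before Build(), mirroring the "MayInplace" behavior described in the API documentation above.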
diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 5002e16ba116c..81c20768c3120 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -447,18 +447,6 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif -#ifdef ORT_ROCM_CTX - template - static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { - thread_local RocmContext rocm_context; - rocm_context.Init(*context); - std::tuple current = std::tuple{rocm_context}; - auto next = CreateTuple(context, args, num_input, num_output, ep); - return std::tuple_cat(current, next); - } -#endif - template static typename std::enable_if::value, std::tuple>::type CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { @@ -674,14 +662,6 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif -#ifdef ORT_ROCM_CTX - template - static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type - ParseArgs(std::vector& input_types, std::vector& output_types) { - ParseArgs(input_types, output_types); - } -#endif - template static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type ParseArgs(std::vector& input_types, std::vector& output_types) { diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index c202b2a9f80e0..71505f51efaca 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -2127,9 +2127,6 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid case ARM_NN: options.addArmNN(false); break; - case ROCM: - options.addROCM(); - break; case CORE_ML: options.addCoreML(); break; diff --git a/js/package-lock.json b/js/package-lock.json index 1e9f5cb29fe6c..0fca515b61238 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4,6 +4,7 @@ "requires": true, "packages": { "": { + "name": "js", "license": "MIT", "devDependencies": { "@eslint/compat": "^1.4.0", @@ -3230,6 +3231,27 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -3242,6 +3264,32 @@ "node": ">=10.13.0" } }, + "node_modules/glob/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/glob/node_modules/minimatch": { + "version": "9.0.5", + 
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/global-agent": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", @@ -4311,43 +4359,6 @@ "balanced-match": "^1.0.0" } }, - "node_modules/mocha/node_modules/glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/mocha/node_modules/glob/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/mocha/node_modules/minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", @@ -8078,6 +8089,40 @@ "get-intrinsic": "^1.2.6" } }, + "glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "requires": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "dependencies": { + "brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "requires": { + "balanced-match": "^1.0.0" + } + }, + "minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "requires": { + "brace-expansion": "^2.0.1" + } + } + } + }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -8772,31 +8817,6 @@ "balanced-match": "^1.0.0" } }, - "glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "requires": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "dependencies": { - 
"minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "requires": { - "brace-expansion": "^2.0.1" - } - } - } - }, "minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index e6ed2bdb9e17b..de8d631362db7 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -33,6 +33,7 @@ "version": "1.24.0", "license": "MIT", "devDependencies": { + "globby": "^15.0.0", "typedoc": "^0.25.7" } }, @@ -61,15 +62,15 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.26.2", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", - "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.25.9", + "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", - "picocolors": "^1.0.0" + "picocolors": "^1.1.1" }, "engines": { "node": ">=6.9.0" @@ -410,9 +411,9 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", - "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "dev": true, "license": "MIT", "engines": { @@ -420,9 +421,9 @@ } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", - "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", + "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", "dev": true, "license": "MIT", "engines": { @@ -455,27 +456,27 @@ } }, "node_modules/@babel/helpers": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", - "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.4.tgz", + "integrity": "sha512-HFN59MmQXGHVyYadKLVumYsA9dBFun/ldYxipEjzA4196jpLZd8UjEEBLkbEkvfYreDqJhZxYAWFPtrfhNpj4w==", "dev": true, "license": "MIT", "dependencies": { - "@babel/template": "^7.25.0", - "@babel/types": "^7.25.6" + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.4" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.26.9", - "resolved": 
"https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", - "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", + "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", "dev": true, "license": "MIT", "dependencies": { - "@babel/types": "^7.26.9" + "@babel/types": "^7.28.5" }, "bin": { "parser": "bin/babel-parser.js" @@ -2114,35 +2115,25 @@ } }, "node_modules/@babel/runtime": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.6.tgz", - "integrity": "sha512-VBj9MYyDb9tuLq7yzqjgzt6Q+IBQLrGZfdjOekyEirZPHxXWoTSGUTMrpsfi58Up73d13NfYLv8HT9vmznjzhQ==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", + "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", "dev": true, "license": "MIT", - "dependencies": { - "regenerator-runtime": "^0.14.0" - }, "engines": { "node": ">=6.9.0" } }, - "node_modules/@babel/runtime/node_modules/regenerator-runtime": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", - "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", - "dev": true, - "license": "MIT" - }, "node_modules/@babel/template": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", - "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", "dev": true, "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.26.2", - "@babel/parser": "^7.26.9", - "@babel/types": "^7.26.9" + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -2189,14 +2180,14 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", - "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", + "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.25.9", - "@babel/helper-validator-identifier": "^7.25.9" + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" }, "engines": { "node": ">=6.9.0" @@ -3319,9 +3310,9 @@ } }, "node_modules/babel-plugin-module-resolver/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": 
"sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3477,7 +3468,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.11", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "license": "MIT", "dependencies": { @@ -3831,9 +3824,9 @@ } }, "node_modules/compression": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz", - "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", + "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", "dev": true, "license": "MIT", "dependencies": { @@ -3841,7 +3834,7 @@ "compressible": "~2.0.18", "debug": "2.6.9", "negotiator": "~0.6.4", - "on-headers": "~1.0.2", + "on-headers": "~1.1.0", "safe-buffer": "5.2.1", "vary": "~1.1.2" }, @@ -4821,9 +4814,9 @@ } }, "node_modules/image-size": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.0.tgz", - "integrity": "sha512-4S8fwbO6w3GeCVN6OPtA9I5IGKkcDMPcKndtUlpJuCwu7JLjtj7JZpwqLuyY2nrmQT3AWsCJLSKPsc2mPBSl3w==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz", + "integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==", "dev": true, "license": "MIT", "dependencies": { @@ -5250,7 +5243,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "3.14.1", + "version": "3.14.2", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", + "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", "dev": true, "license": "MIT", "dependencies": { @@ -6544,9 +6539,9 @@ } }, "node_modules/on-headers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", - "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", + "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", "dev": true, "license": "MIT", "engines": { @@ -7130,9 +7125,9 @@ "license": "Python-2.0" }, "node_modules/react-native-builder-bob/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -7203,9 +7198,9 @@ } }, "node_modules/react-native-builder-bob/node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": 
"sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a8dffb73fa08..f0f7527f665b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -360,7 +360,7 @@ const createInPlaceSoftmaxProgramInfo = ( let local_offset = local_idx * uniforms.elements_per_thread; let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; - var thread_max_vector = ${f32Type}(-3.402823e+38f); + var thread_max_vector = ${f32Type}(-3.4028234663852886e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } @@ -378,7 +378,7 @@ const createInPlaceSoftmaxProgramInfo = ( })()}; workgroupBarrier(); - var max_value = f32(-3.402823e+38f); + var max_value = f32(-3.4028234663852886e+38f); for (var i = 0u; i < ${WG}; i++) { max_value = max(thread_max[i], max_value); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 2056416873df5..f6882280e91df 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -81,7 +81,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt // 6.2.4 in wgsl spec const threadMaxDecl = tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32' - ? `var threadMax = ${valueType}(-3.402823e+38f);` + ? 
`var threadMax = ${valueType}(-3.4028234663852886e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (shaderHelper: ShaderHelper) => ` var rowMaxShared : ${valueType}; diff --git a/js/web/test/e2e/exports/testcases/nextjs-default/package-lock.json b/js/web/test/e2e/exports/testcases/nextjs-default/package-lock.json index facda7bcaf2c0..4baa6dd41d9a8 100644 --- a/js/web/test/e2e/exports/testcases/nextjs-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/nextjs-default/package-lock.json @@ -8,7 +8,7 @@ "name": "nextjs-default", "version": "0.1.0", "dependencies": { - "next": "15.4.7", + "next": "15.4.8", "react": "^19.0.0", "react-dom": "^19.0.0" } @@ -442,15 +442,15 @@ } }, "node_modules/@next/env": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/env/-/env-15.4.7.tgz", - "integrity": "sha512-PrBIpO8oljZGTOe9HH0miix1w5MUiGJ/q83Jge03mHEE0E3pyqzAy2+l5G6aJDbXoobmxPJTVhbCuwlLtjSHwg==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/env/-/env-15.4.8.tgz", + "integrity": "sha512-LydLa2MDI1NMrOFSkO54mTc8iIHSttj6R6dthITky9ylXV2gCGi0bHQjVCtLGRshdRPjyh2kXbxJukDtBWQZtQ==", "license": "MIT" }, "node_modules/@next/swc-darwin-arm64": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.4.7.tgz", - "integrity": "sha512-2Dkb+VUTp9kHHkSqtws4fDl2Oxms29HcZBwFIda1X7Ztudzy7M6XF9HDS2dq85TmdN47VpuhjE+i6wgnIboVzQ==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.4.8.tgz", + "integrity": "sha512-Pf6zXp7yyQEn7sqMxur6+kYcywx5up1J849psyET7/8pG2gQTVMjU3NzgIt8SeEP5to3If/SaWmaA6H6ysBr1A==", "cpu": [ "arm64" ], @@ -464,9 +464,9 @@ } }, "node_modules/@next/swc-darwin-x64": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.4.7.tgz", - "integrity": "sha512-qaMnEozKdWezlmh1OGDVFueFv2z9lWTcLvt7e39QA3YOvZHNpN2rLs/IQLwZaUiw2jSvxW07LxMCWtOqsWFNQg==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.4.8.tgz", + "integrity": "sha512-xla6AOfz68a6kq3gRQccWEvFC/VRGJmA/QuSLENSO7CZX5WIEkSz7r1FdXUjtGCQ1c2M+ndUAH7opdfLK1PQbw==", "cpu": [ "x64" ], @@ -480,9 +480,9 @@ } }, "node_modules/@next/swc-linux-arm64-gnu": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.4.7.tgz", - "integrity": "sha512-ny7lODPE7a15Qms8LZiN9wjNWIeI+iAZOFDOnv2pcHStncUr7cr9lD5XF81mdhrBXLUP9yT9RzlmSWKIazWoDw==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.4.8.tgz", + "integrity": "sha512-y3fmp+1Px/SJD+5ntve5QLZnGLycsxsVPkTzAc3zUiXYSOlTPqT8ynfmt6tt4fSo1tAhDPmryXpYKEAcoAPDJw==", "cpu": [ "arm64" ], @@ -496,9 +496,9 @@ } }, "node_modules/@next/swc-linux-arm64-musl": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.4.7.tgz", - "integrity": "sha512-4SaCjlFR/2hGJqZLLWycccy1t+wBrE/vyJWnYaZJhUVHccpGLG5q0C+Xkw4iRzUIkE+/dr90MJRUym3s1+vO8A==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.4.8.tgz", + "integrity": "sha512-DX/L8VHzrr1CfwaVjBQr3GWCqNNFgyWJbeQ10Lx/phzbQo3JNAxUok1DZ8JHRGcL6PgMRgj6HylnLNndxn4Z6A==", "cpu": [ "arm64" ], @@ -512,9 +512,9 @@ } }, "node_modules/@next/swc-linux-x64-gnu": { - "version": "15.4.7", - "resolved": 
"https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.4.7.tgz", - "integrity": "sha512-2uNXjxvONyRidg00VwvlTYDwC9EgCGNzPAPYbttIATZRxmOZ3hllk/YYESzHZb65eyZfBR5g9xgCZjRAl9YYGg==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.4.8.tgz", + "integrity": "sha512-9fLAAXKAL3xEIFdKdzG5rUSvSiZTLLTCc6JKq1z04DR4zY7DbAPcRvNm3K1inVhTiQCs19ZRAgUerHiVKMZZIA==", "cpu": [ "x64" ], @@ -528,9 +528,9 @@ } }, "node_modules/@next/swc-linux-x64-musl": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.4.7.tgz", - "integrity": "sha512-ceNbPjsFgLscYNGKSu4I6LYaadq2B8tcK116nVuInpHHdAWLWSwVK6CHNvCi0wVS9+TTArIFKJGsEyVD1H+4Kg==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.4.8.tgz", + "integrity": "sha512-s45V7nfb5g7dbS7JK6XZDcapicVrMMvX2uYgOHP16QuKH/JA285oy6HcxlKqwUNaFY/UC6EvQ8QZUOo19cBKSA==", "cpu": [ "x64" ], @@ -544,9 +544,9 @@ } }, "node_modules/@next/swc-win32-arm64-msvc": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.4.7.tgz", - "integrity": "sha512-pZyxmY1iHlZJ04LUL7Css8bNvsYAMYOY9JRwFA3HZgpaNKsJSowD09Vg2R9734GxAcLJc2KDQHSCR91uD6/AAw==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.4.8.tgz", + "integrity": "sha512-KjgeQyOAq7t/HzAJcWPGA8X+4WY03uSCZ2Ekk98S9OgCFsb6lfBE3dbUzUuEQAN2THbwYgFfxX2yFTCMm8Kehw==", "cpu": [ "arm64" ], @@ -560,9 +560,9 @@ } }, "node_modules/@next/swc-win32-x64-msvc": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.4.7.tgz", - "integrity": "sha512-HjuwPJ7BeRzgl3KrjKqD2iDng0eQIpIReyhpF5r4yeAHFwWRuAhfW92rWv/r3qeQHEwHsLRzFDvMqRjyM5DI6A==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.4.8.tgz", + "integrity": "sha512-Exsmf/+42fWVnLMaZHzshukTBxZrSwuuLKFvqhGHJ+mC1AokqieLY/XzAl3jc/CqhXLqLY3RRjkKJ9YnLPcRWg==", "cpu": [ "x64" ], @@ -691,12 +691,12 @@ } }, "node_modules/next": { - "version": "15.4.7", - "resolved": "https://registry.npmjs.org/next/-/next-15.4.7.tgz", - "integrity": "sha512-OcqRugwF7n7mC8OSYjvsZhhG1AYSvulor1EIUsIkbbEbf1qoE5EbH36Swj8WhF4cHqmDgkiam3z1c1W0J1Wifg==", + "version": "15.4.8", + "resolved": "https://registry.npmjs.org/next/-/next-15.4.8.tgz", + "integrity": "sha512-jwOXTz/bo0Pvlf20FSb6VXVeWRssA2vbvq9SdrOPEg9x8E1B27C2rQtvriAn600o9hH61kjrVRexEffv3JybuA==", "license": "MIT", "dependencies": { - "@next/env": "15.4.7", + "@next/env": "15.4.8", "@swc/helpers": "0.5.15", "caniuse-lite": "^1.0.30001579", "postcss": "8.4.31", @@ -709,14 +709,14 @@ "node": "^18.18.0 || ^19.8.0 || >= 20.0.0" }, "optionalDependencies": { - "@next/swc-darwin-arm64": "15.4.7", - "@next/swc-darwin-x64": "15.4.7", - "@next/swc-linux-arm64-gnu": "15.4.7", - "@next/swc-linux-arm64-musl": "15.4.7", - "@next/swc-linux-x64-gnu": "15.4.7", - "@next/swc-linux-x64-musl": "15.4.7", - "@next/swc-win32-arm64-msvc": "15.4.7", - "@next/swc-win32-x64-msvc": "15.4.7", + "@next/swc-darwin-arm64": "15.4.8", + "@next/swc-darwin-x64": "15.4.8", + "@next/swc-linux-arm64-gnu": "15.4.8", + "@next/swc-linux-arm64-musl": "15.4.8", + "@next/swc-linux-x64-gnu": "15.4.8", + "@next/swc-linux-x64-musl": "15.4.8", + "@next/swc-win32-arm64-msvc": "15.4.8", + "@next/swc-win32-x64-msvc": "15.4.8", "sharp": "^0.34.3" }, "peerDependencies": { @@ 
-781,6 +781,7 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.0.0.tgz", "integrity": "sha512-V8AVnmPIICiWpGfm6GLzCR/W5FXLchHop40W4nXBmdlEceh16rCN8O8LNWm5bh5XUX91fh7KpA+W0TgMKmgTpQ==", "license": "MIT", + "peer": true, "engines": { "node": ">=0.10.0" } @@ -790,6 +791,7 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.0.0.tgz", "integrity": "sha512-4GV5sHFG0e/0AD4X+ySy6UJd3jVl1iNsNHdpad0qhABJ11twS3TTBnseqsKurKcsNqCEFeGL3uLpVChpIO3QfQ==", "license": "MIT", + "peer": true, "dependencies": { "scheduler": "^0.25.0" }, diff --git a/js/web/test/e2e/exports/testcases/nextjs-default/package.json b/js/web/test/e2e/exports/testcases/nextjs-default/package.json index e99b9e483f481..a285091748034 100644 --- a/js/web/test/e2e/exports/testcases/nextjs-default/package.json +++ b/js/web/test/e2e/exports/testcases/nextjs-default/package.json @@ -11,6 +11,6 @@ "dependencies": { "react": "^19.0.0", "react-dom": "^19.0.0", - "next": "15.4.7" + "next": "15.4.8" } } diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 8af6faadd6e92..ba6da7284247f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -346,12 +346,11 @@ Status CheckInputs(const T* query, // // The following inputs are not used in cross attention (so they are None for cross attention): // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. - // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // For CUDA, past_present_share_buffer is always True. // past_value : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. - // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // For CUDA, past_present_share_buffer is always True. // past_sequence_length : scalar (1) when past_present_share_buffer is True. // CUDA version has extra inputs (beam_width, cache_indirection) that are not checked in the class. - // For ROCm, see contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh for more details. // --------------------------------------------------------------- AttentionQkvFormat qkv_format = UNKNOWN; diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index ceb498372a6fc..2bba0adcd987c 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME2() #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/mlas/inc/mlas.h" @@ -213,9 +212,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { } } - // Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops. - // We check that here too before attempting to use them. 
- if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME2()) { + if (!MlasIsDynamicQGemmAvailable()) { can_use_dynamic_quant_mlas_ = false; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index 537d066b264a1..4597b9c7d6605 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -132,8 +132,7 @@ const IExecutionProvider* Subgraph::GetProvider() const { const ExecutionProviders& providers = session_state_->GetExecutionProviders(); const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider); const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider); - const IExecutionProvider* rocm_provider = providers.Get(onnxruntime::kRocmExecutionProvider); - const IExecutionProvider* gpu_provider = cuda_provider ? cuda_provider : rocm_provider; + const IExecutionProvider* gpu_provider = cuda_provider; const IExecutionProvider* provider = gpu_provider ? gpu_provider : cpu_provider; return provider; } diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 0c4d75aeddac0..0e97d5387c0a5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -216,7 +216,6 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o } } -#ifndef USE_ROCM template __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* output, T* qkv_add_bias, const int rotary_embedding_dim, const int head_size, const int step, @@ -359,7 +358,6 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o } } } -#endif // this suppose 3 matrix in total template @@ -677,9 +675,7 @@ void InvokeAddBiasTranspose( assert(num_heads <= max_threads_per_block); if (do_rotary) { -#ifdef USE_ROCM - ORT_THROW("Rotary Attention is not supported on ROCm"); -#elif !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 if (format != 1 && format != 2 && format != 3) { ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index bc5f4871283bb..985d81d558716 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -787,8 +787,6 @@ Status UnfusedAttention( return result; } -#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP - template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, @@ -859,7 +857,6 @@ template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_ cudaStream_t stream, int max_threads_per_block, AttentionData& data); -#endif template Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, int v_head_size, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 80152e918ae30..84f651ca5470d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -250,8 +250,6 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, return CUDA_CALL(cudaGetLastError()); } -#ifndef 
USE_ROCM // exclude the following from hipify since they are not used in ROCM EP - // ---------------------------------------------------------------------------------- // Below kernels are for past and present sharing buffer // ---------------------------------------------------------------------------------- @@ -397,7 +395,6 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const BFloat16* bias, const BFloat16* qkv_buffer, BFloat16* present); -#endif // Kernel to append new and past kv in either BSNH or BNSH format // Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index ee49f362564a6..5e5f909415fff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -370,11 +370,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { return LaunchDecoderAttentionKernel( device_prop, -#ifdef USE_ROCM - GetTuningContext(), -#else UseTF32(), -#endif context->GetComputeStream(), cublas, element_size, diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 3a16f16466ed3..e7ed96d7f5ee2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -6,11 +6,7 @@ #include "fast_gelu.h" #include "core/providers/cuda/tensor/gelu_impl.h" #include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#ifdef USE_ROCM -#include "contrib_ops/rocm/bert/elementwise.h" -#else #include "contrib_ops/cuda/bert/transformer_common.h" -#endif namespace onnxruntime { namespace contrib { @@ -36,10 +32,8 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { -#ifndef USE_ROCM const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); -#endif } template @@ -57,13 +51,6 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; -#ifdef USE_ROCM - return LaunchElementwiseKernel( - GetTuningContext(), context->GetComputeStream(), - reinterpret_cast(input->Data()), static_cast(input_length), - (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length), - reinterpret_cast(output->MutableData())); -#else return LaunchFastGeluKernel(GetDeviceProp(), Stream(context), static_cast(input_length), @@ -72,7 +59,6 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { (nullptr != bias) ? 
reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(output->MutableData()), use_half2_); -#endif } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 26f3bd5a03928..3e642a70afef5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,9 +18,7 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: -#ifndef USE_ROCM bool use_half2_; -#endif }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h index 320aa2a552198..9238dde012c3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h @@ -27,8 +27,6 @@ using namespace onnxruntime::cuda; namespace onnxruntime { namespace cuda { -#ifndef USE_ROCM - inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step) { const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim); return {cos(inv_freq), sin(inv_freq)}; @@ -422,7 +420,5 @@ __device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, smem[smem_pitch + transpose_idx] = vec.y; } -#endif - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/utils.cuh b/onnxruntime/contrib_ops/cuda/bert/utils.cuh index a45664083f5c7..83c853548abda 100644 --- a/onnxruntime/contrib_ops/cuda/bert/utils.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/utils.cuh @@ -77,8 +77,6 @@ struct Float4_ { float2 y; }; -#ifndef USE_ROCM - template struct num_elems; template <> @@ -935,7 +933,5 @@ inline __device__ void ConvertFromFloat(uint4& dst, Float8_ src) { dst.w = Float2ToHalf2(src.w); } -#endif - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc index feb6613690c08..f421a0db5a2f9 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc @@ -259,7 +259,6 @@ Status AllReduce::ComputeInternal(OpKernelContext* context) const { void* output_data = context->Output(0, in_shape)->MutableDataRaw(); -#ifndef USE_ROCM return FuncCustomAllReduce(nccl_, Stream(context), input_data, @@ -267,12 +266,6 @@ Status AllReduce::ComputeInternal(OpKernelContext* context) const { input_count, input_tensor->DataType(), onnxruntime::cuda::collective::IPCMemoryResourcePack::GetGlobalInstance()); -#else - ncclComm_t comm = nccl_->Comm(); - ncclDataType_t dtype = GetNcclDataType(input_tensor->DataType()); - NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, Stream(context))); - return Status::OK(); -#endif } AllGather::AllGather(const OpKernelInfo& info) : NcclKernel(info) { @@ -428,7 +421,6 @@ Status FuncAllReduce( return Status::OK(); } -#ifndef USE_ROCM Status FuncCustomAllReduce( NcclContext* nccl, cudaStream_t stream, @@ -478,7 +470,6 @@ Status FuncCustomAllReduce( return Status::OK(); } -#endif static std::vector CalculatePermToSwapAxes( const int64_t axis, diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h index 49646637b635e..8dac48492cc12 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h +++ 
b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h @@ -6,11 +6,9 @@ #include "core/providers/cuda/cuda_kernel.h" #if defined(ORT_USE_NCCL) || defined(USE_MPI) -#ifndef USE_ROCM #include "custom_reduce_impl.h" #include "ipc_utils.h" #endif -#endif #if defined(ORT_USE_NCCL) #include @@ -107,7 +105,6 @@ Status FuncAllReduce( const Tensor* input, Tensor* output); -#ifndef USE_ROCM Status FuncCustomAllReduce( NcclContext* nccl, cudaStream_t stream, @@ -116,7 +113,6 @@ Status FuncCustomAllReduce( int64_t input_count, onnxruntime::MLDataType data_type, onnxruntime::cuda::collective::IPCMemoryResourcePack& ipc_mem_res_pack); -#endif void FuncAllGather( const NcclKernel* nccl_kernel, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 58c94f966841b..f12e6c530ff35 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -128,7 +128,6 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { - // tuning_ctx only used for ROCm EP. ORT_UNUSED_PARAMETER(tuning_ctx); GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast(workspace), epsilon, diff --git a/onnxruntime/contrib_ops/cuda/math/bias_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/math/bias_gelu_impl.cu index 5a13520254e6d..35bfc111c6492 100644 --- a/onnxruntime/contrib_ops/cuda/math/bias_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/bias_gelu_impl.cu @@ -14,11 +14,7 @@ namespace cuda { namespace { constexpr int kElementsPerThread = GridDim::maxElementsPerThread; -#ifdef USE_ROCM -constexpr int kThreadsPerBlock = 512; -#else constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock; -#endif } // namespace diff --git a/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc b/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc index a95965775484d..db5e0e30af46d 100644 --- a/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc +++ b/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc @@ -31,12 +31,7 @@ struct DispatchBiasSoftmaxImpl { } // namespace -// MIOpen doesn't support double so ROCm kernel doesn't have double support for now. -#ifdef USE_ROCM -#define BIAS_SOFTMAX_TYPES float, MLFloat16 -#else #define BIAS_SOFTMAX_TYPES float, MLFloat16, double -#endif ONNX_OPERATOR_KERNEL_EX( BiasSoftmax, kMSDomain, 1, kCudaExecutionProvider, diff --git a/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu b/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu index 427c7fc624309..e665c35269b6f 100644 --- a/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu @@ -41,11 +41,7 @@ __global__ void BiasSoftmaxWarpForward(output_t* output, const input_t* input, c constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = next_power_of_two < GPU_WARP_SIZE ? next_power_of_two : GPU_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; -#ifdef USE_ROCM - constexpr int WARP_BATCH = 1; -#else constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; -#endif // each "WARP" (<=32) processes WARP_BATCH(one of {1,2}) batches int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; @@ -137,13 +133,8 @@ Status BiasSoftmaxImpl(cudaStream_t stream, cudnnHandle_t cudnn_handle, T* outpu int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE_HOST); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. -#ifdef USE_ROCM - int batches_per_warp = 1; - constexpr int threads_per_block = 256; -#else int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; constexpr int threads_per_block = 128; -#endif int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; @@ -229,7 +220,7 @@ Status BiasSoftmaxImpl(cudaStream_t stream, cudnnHandle_t cudnn_handle, T* outpu const T* input_data, const T* bias_data, int element_count, int batch_count, \ bool is_inner_broadcast, int bias_broadcast_size); -// MIOpen doesn't support double so ROCm kernel doesn't have double support for now. +// MIOpen doesn't support double for now. SPECIALIZED_BIAS_SOFTMAX_IMPL(float) SPECIALIZED_BIAS_SOFTMAX_IMPL(half) #ifdef USE_CUDA diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index bec78d081ef69..afdd25a617ce3 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -70,9 +70,7 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::InitBeamState, GenerationCudaDeviceHelper::CreateBeamScorer); -#ifndef USE_ROCM SetDeviceHelpers_Cuda(GenerationCudaDeviceHelper::ReorderPastState, GenerationCudaDeviceHelper::InitCacheIndir); -#endif SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, GenerationCudaDeviceHelper::UpdateGptFeeds); @@ -87,12 +85,10 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) SetConsoleDumper(&g_cuda_dumper); -#ifndef USE_ROCM cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + static_cast(cuda_device_prop_)->minor * 10; -#endif } Status BeamSearch::ComputeInternal(OpKernelContext* context) const { @@ -124,9 +120,7 @@ WhisperBeamSearch::WhisperBeamSearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::InitBeamState, GenerationCudaDeviceHelper::CreateBeamScorer); -#ifndef USE_ROCM SetDeviceHelpers_Cuda(GenerationCudaDeviceHelper::ReorderPastState, GenerationCudaDeviceHelper::InitCacheIndir); -#endif SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, GenerationCudaDeviceHelper::UpdateGptFeeds); @@ -141,12 +135,10 @@ WhisperBeamSearch::WhisperBeamSearch(const OpKernelInfo& info) SetConsoleDumper(&g_cuda_dumper); -#ifndef USE_ROCM cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + static_cast(cuda_device_prop_)->minor * 10; -#endif } Status WhisperBeamSearch::ComputeInternal(OpKernelContext* context) const { diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu index 44be2ef2375ee..dee9f9a95abcb 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu @@ -226,11 +226,9 @@ void TopKLauncherMaxK( dim3 grid(batch_size * num_beams, voc_parts); -#ifndef USE_ROCM 
cudaFuncSetAttribute(BeamSearchOnlineTopKStage1Kernel, cudaFuncAttributePreferredSharedMemoryCarveout, cudaSharedmemCarveoutMaxL1); -#endif // !USE_ROCM BeamSearchOnlineTopKStage1Kernel <<>>(input, K, vocab_size, (vocab_size + voc_parts - 1) / voc_parts, output_values_tmp, output_indices_tmp); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 94cc13e4b3b1c..52614a81d623f 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1263,7 +1263,6 @@ void UpdateDecoderMaskedMultiHeadAttentionCacheIndirection(int32_t* tgt_indir_ca current_length); } -#ifndef USE_ROCM namespace { template struct TypeMapper : public V_vec_m_ {}; @@ -1278,7 +1277,6 @@ struct TypeMapper { using Type = uint4; }; } // namespace -#endif template __global__ void KeyCacheExpansionKernel(const T* input, @@ -1330,7 +1328,6 @@ void KeyCacheExpansionKernelLauncher(const T* key_cache, tpb |= (tpb >> 16); tpb++; -#ifndef USE_ROCM if ((head_size % 4) == 0) { using vec_type = typename TypeMapper::Type; const dim3 block(tpb); @@ -1348,16 +1345,13 @@ void KeyCacheExpansionKernelLauncher(const T* key_cache, max_seq_length, equiv_head_size); } else { -#endif const dim3 block(tpb); KeyCacheExpansionKernel<<>>(key_cache, key_cache_expanded, beam_width, max_seq_length, head_size); -#ifndef USE_ROCM } -#endif } template void KeyCacheExpansionKernelLauncher(const float* key_cache, @@ -1417,7 +1411,6 @@ void BufferExpansionKernelLauncher(const T* input, cudaStream_t stream) { const dim3 block(128); -#ifndef USE_ROCM if ((chunk_size % 4) == 0) { using vec_type = typename TypeMapper::Type; const dim3 grid(batch_size, beam_width, (chunk_size / 4 + block.x - 1) / block.x); @@ -1431,14 +1424,11 @@ void BufferExpansionKernelLauncher(const T* input, reinterpret_cast(output), chunk_size / 2); } else { -#endif const dim3 grid(batch_size, beam_width, (chunk_size + block.x - 1) / block.x); BufferExpansionKernel<<>>(input, output, chunk_size); -#ifndef USE_ROCM } -#endif } template void BufferExpansionKernelLauncher(const float* input, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index d20d0b4218bd3..a3781c8e6cfa3 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -21,7 +21,6 @@ #include "contrib_ops/cuda/transformers/greedy_search_top_one.h" #include "core/providers/cuda/tensor/transpose.h" -// the includes would be dummy for ROCm, we will ignore them for now #ifdef ENABLE_NVTX_PROFILE #include "core/providers/cuda/nvtx_profile.h" #include "core/providers/cuda/nvtx_profile_context.h" diff --git a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc index cf623ab36015e..69756684b0c32 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc @@ -39,21 +39,17 @@ GreedySearch::GreedySearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::InitGreedyState, GenerationCudaDeviceHelper::InitGreedyState); -#ifndef USE_ROCM SetDeviceHelpers_Cuda(GenerationCudaDeviceHelper::ReorderPastState); -#endif SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, 
GenerationCudaDeviceHelper::UpdateGptFeeds); SetConsoleDumper(&g_cuda_dumper_greedysearch); -#ifndef USE_ROCM cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + static_cast(cuda_device_prop_)->minor * 10; -#endif } Status GreedySearch::ComputeInternal(OpKernelContext* context) const { diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc index a9cbdfd324ad7..c61ef36529174 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc @@ -40,21 +40,17 @@ Sampling::Sampling(const OpKernelInfo& info) GenerationCudaDeviceHelper::InitGreedyState, GenerationCudaDeviceHelper::InitGreedyState); -#ifndef USE_ROCM SetDeviceHelpers_Cuda(GenerationCudaDeviceHelper::ReorderPastState); -#endif SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, GenerationCudaDeviceHelper::UpdateGptFeeds); SetConsoleDumper(&g_cuda_dumper_sampling); -#ifndef USE_ROCM gpu_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); gpu_device_arch_ = static_cast(gpu_device_prop_)->major * 100 + static_cast(gpu_device_prop_)->minor * 10; -#endif } Status Sampling::ComputeInternal(OpKernelContext* context) const { diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu deleted file mode 100644 index b40fc2bf0eef8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/attention.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "contrib_ops/rocm/bert/transformer_common.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr int kPastSequenceLengthInputIndex = 6; -constexpr int kPastInputIndex = 4; -constexpr int kPresentOutputIndex = 1; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Attention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ - Attention); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -Attention::Attention(const OpKernelInfo& info) - : RocmKernel(info), AttentionBase(info, true), attn_type_(kAttention) { - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - tunable_op_ = std::make_shared(); -} - -template -Status Attention::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* weights = context->Input(1); - const Tensor* bias = context->Input(2); - const Tensor* mask_index = context->Input(3); - const Tensor* past = context->Input(4); - const Tensor* attention_bias = context->Input(5); - 
const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); - - auto& device_prop = GetDeviceProp(); - RocmAttentionParameters attn; - ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), - weights->Shape(), - bias->Shape(), - mask_index, - past, - attention_bias, - &attn, - device_prop.maxThreadsPerBlock, - past_seq_len)); - ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention - ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(attn.batch_size); - output_shape[1] = static_cast(attn.sequence_length); - output_shape[2] = static_cast(attn.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector present_dims{ - 2, attn.batch_size, attn.num_heads, - past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, - attn.head_size}; - TensorShape present_shape(present_dims); - Tensor* present = context->Output(kPresentOutputIndex, present_shape); - - auto stream = Stream(context); - hipblasHandle_t hipblas = GetHipblasHandle(context); - - using HipT = typename ToHipType::MappedType; - using QkvProjectGeneric = GemmPermuteGenericPipeline; - using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - - ORT_RETURN_IF_ERROR(ClassifyAttentionMode(attn_type_, &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present})); - ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE || - attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE || - attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE || - attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE || - attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE); - - size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); - size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn), - AttentionGeneric::GetWorkspaceNumBytes(&attn)); - if (GetTuningContext()->IsTunableOpEnabled()) { - shared_workspace_bytes = std::max(shared_workspace_bytes, AttentionTunableOp::GetWorkspaceNumBytes(&attn)); - } - - auto qkv_project_output = GetScratchBuffer(qkv_project_output_bytes, context->GetComputeStream()); - auto workspace = GetScratchBuffer(shared_workspace_bytes, context->GetComputeStream()); - - GemmPermuteParams gemm_permute_params; - { - auto& params = gemm_permute_params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = hipblas; - params.attention = &attn; - params.device_prop = &device_prop; - - params.input_buffer = reinterpret_cast(input->DataRaw()); - params.weight_buffer = reinterpret_cast(weights->DataRaw()); - params.bias_buffer = reinterpret_cast(bias->DataRaw()); - params.out_buffer = reinterpret_cast(qkv_project_output.get()); - params.ones = GetConstOnes(attn.batch_size * attn.sequence_length, stream); - params.workspace_buffer = reinterpret_cast(workspace.get()); - } - - ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params)); - auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params); - - // NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH - if (nullptr != present) { - Strides dst_strides; // the output buffer is present Tensor, the buffer is the same - - int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size}; - HipT* add_dest = nullptr; // destination of concatenated data to present - const HipT* 
const add_src = k_buffer; // source of concatenated data to present - const auto add_src_strides = Strides::BNSHMemory( - 2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size); - - if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; - } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - - // We only need to copy past to present in this case. All other cases will be build the present incrementally - const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; - HipT* const past_dest = reinterpret_cast(present->MutableDataRaw()); - const HipT* const past_src = reinterpret_cast(past->DataRaw()); - const Strides past_src_strides = Strides::BNSHMemory( - 2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); - - ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(), - past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; - } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - } - - ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(), - add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - - // update pointers to present_k and present_v. 
TODO: switch to ConvertToOffsetedBufferViews - k_buffer = reinterpret_cast(present->MutableDataRaw()); - v_buffer = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0); - } - - // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax - const TransformerOptions* options = TransformerOptions::GetInstance(); - bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - - GemmSoftmaxGemmPermuteParams gemm_softmax_gemm_permute_params; - { - auto& params = gemm_softmax_gemm_permute_params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = hipblas; - params.attention = &attn; - params.device_prop = &device_prop; - // FIXME: the params.scale seems to be different from AttentionParameters::scale; - params.scale = 1.0f / sqrt(static_cast(attn.head_size)); - // TODO: switch to ConvertToOffsetedBufferViews - params.q_buffer = q_buffer; - params.k_buffer = k_buffer; - params.v_buffer = v_buffer; - params.out_buffer = reinterpret_cast(output->MutableDataRaw()); - - if (attention_bias != nullptr) { - params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); - } - - if (mask_index != nullptr) { - params.mask_index_buffer = mask_index->Data(); - params.mask_index_dims = mask_index->Shape().AsShapeVector(); - } - - params.workspace_buffer = reinterpret_cast(workspace.get()); - } - - if (this->GetTuningContext()->IsTunableOpEnabled() && - !use_persistent_softmax) { - return (*std::static_pointer_cast(tunable_op_))(&gemm_softmax_gemm_permute_params); - } else { - return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.h b/onnxruntime/contrib_ops/rocm/bert/attention.h deleted file mode 100644 index 7204fd660a516..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class Attention final : public RocmKernel, public AttentionBase { - public: - Attention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - public: - AttentionType attn_type_; - - // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: - // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. - // 2. We don't want to construct the object repeatly (which is expansive) during Compute. 
- std::shared_ptr tunable_op_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu deleted file mode 100644 index 270a8e51daf88..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ /dev/null @@ -1,435 +0,0 @@ -/* - The implementation of this file is based on qkvToContext plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Modifications: scaling is moved from masked softmax to the gemm before that. -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/attention_softmax.h" -#include "contrib_ops/rocm/bert/decoder_attention_impl.h" - -using namespace onnxruntime::rocm; - -namespace blas = onnxruntime::rocm::tunable::blas; - -#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr) - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static size_t AlignTo(size_t a, size_t b) { - return CeilDiv(a, b) * b; -} - -size_t GetAttentionScratchSize(size_t element_size, - int batch_size, - int num_heads, - int sequence_length, - int total_sequence_length) { - const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length; - - const size_t alignment = 256; - const size_t bytesAligned = AlignTo(bytes, alignment); - return bytesAligned; -} - -size_t GetAttentionWorkspaceSize( - size_t element_size, - int batch_size, - int num_heads, - int head_size, - int sequence_length, - int total_sequence_length) { - size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size; - return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, - sequence_length, total_sequence_length); -} - -inline int3 Get2DMaskStrides(int total_sequence_length) { - // stride == 0 indicate broadcasting - return {total_sequence_length, 0, 1}; -} - -Status ClassifyAttentionMode( - AttentionType attn_type, - RocmAttentionParameters* attn, - const std::vector& qkv, - const std::vector& past, - const std::vector& present) { - size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; }); - size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; }); - size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; 
}); - - auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs"); - LOGS_DEFAULT(VERBOSE) << hint; - - if (attn_type == kAttention) { - ORT_ENFORCE(num_qkv == 0); - if (num_past == 0 && num_present == 0) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (num_past == 0 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE; - return Status::OK(); - } - } else if (num_past == 1 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE; - return Status::OK(); - } - } - } else if (attn_type == kMultiHeadAttention || attn_type == kDecoderMaskedMultiHeadAttention) { - if (num_qkv == 3 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 3 && num_past == 0 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 1 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == QKV_BSN3H) { - attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 2 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_KV_BSNH_BSN2H) { - attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } - } - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Unsupported AttentionMode for ", attn_type, ". Got qkv format ", attn->qkv_format, - ". 
Got ", hint); -} - -template -Status DecoderQkvToContext( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { - const int max_threads_per_block = prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; - - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); - - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, key_cache, - temp_qkv_buffer, qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } - } - } - - if (has_layer_state) { - if (use_past && static_kv) { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, key_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, value_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } else { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, k, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, v, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } - } - - // scratch1: BxNxSxS* buffer - // scratch2: BxNxSxS* buffer - // scratch3: BxNxSxH buffer - T* scratch1 = 
temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* - // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - key_cache, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - k, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } - - if (has_key_padding_mask) { - int3 strides = Get2DMaskStrides(kv_sequence_length); - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, num_heads, - strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2, - false, 1.0f, false, nullptr, mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, scratch1, scratch2, false)); - } - - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - v, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } - - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, scratch3, output); -} - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - prop, 
- tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - prop, - tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h deleted file mode 100644 index 07d875e90fa4b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -typedef struct __align__(32) { - long long int x, y, z, w; -} LongLong4; - -size_t GetAttentionScratchSize( - size_t element_size, - int batch_size, - int num_heads, - int sequence_length, - int all_sequence_length); - -size_t GetAttentionWorkspaceSize( - size_t element_size, - int batch_size, - int num_heads, - int head_size, - int sequence_length, - int past_sequence_length); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output, - int total_matrix_count = -1); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, - int total_matrix_count = -1); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int 
max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out); - -inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const void* alpha, - const void* A, - hipDataType a_type, - int lda, - hipblasStride stride_A, - const void* b, - hipDataType b_type, - int ldb, - hipblasStride stride_b, - const void* beta, - void* c, - hipDataType c_type, - int ldc, - hipblasStride stride_c, - int batch_count, - hipblasComputeType_t compute_type, - hipblasGemmAlgo_t algo) { - return hipblasGemmStridedBatchedEx(handle, - transa, - transb, - m, // m - n, // n - k, // k - alpha, // alpha - A, // A - a_type, // A type - lda, // lda - stride_A, // strideA - b, // B - b_type, // B type - ldb, // ldb - stride_b, // strideB - beta, // beta - c, // C - c_type, // C type - ldc, // ldc - stride_c, // strideC - batch_count, // batch count - compute_type, - algo); -} - -// Compatible for CublasMathModeSetter -class CompatHipblasMathModeSetter { - public: - CompatHipblasMathModeSetter(const hipDeviceProp_t&, - hipblasHandle_t, - int) { - } -}; - -enum AttentionMode { - // Q,K,V,PastK,PastV,PresentK,PresentV - QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE, - QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, - BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, - BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, - BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH, - BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH, - BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH, - BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH, - BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, - BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, - BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, - BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH, - BLN3H_NONE_NONE_NONE_NONE_NONE_NONE, - BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE, -}; - -struct RocmAttentionParameters : AttentionParameters { - AttentionMode mode; -}; - -Status ClassifyAttentionMode(AttentionType type, - RocmAttentionParameters* attn, - const std::vector& qkv, - const std::vector& past, - const std::vector& present); - -template -Status LaunchStridedCopy( - hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) - int max_threads_per_block); - -template -Status LaunchStridedCopy(hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) - T* out, LongLong4 out_strides, // coord (b,n,s,h) - int max_threads_per_block); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h deleted file mode 100644 index 9f2faa228cf79..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ /dev/null @@ -1,465 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on qkvToContext plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA 
Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/math/softmax.h" - -#define ROCMRT_INF_F __int_as_float(0x7f800000) - -using namespace onnxruntime::rocm; -using namespace hipcub; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -__device__ inline void Softmax(const int all_sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - float thread_data_max(-ROCMRT_INF_F); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - if (thread_data_max < input_at_idx) { - thread_data_max = input_at_idx; - } - } - } - - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_sum(0.f); - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index]; - thread_data_sum += expf(val - max_block); - } - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum()); - if (threadIdx.x == 0) { - sum_reverse_block = 1.f / sum; - } - __syncthreads(); - - for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - const float val = (i >= valid_start && i < valid_end) ? 
expf(input_at_idx - max_block) * sum_reverse_block : 0.f; - output[index] = T(val); - } -} - -template -__device__ inline void SoftmaxSmall(const int all_sequence_length, - const int sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output, - bool causal) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int index = offset + threadIdx.x; - - bool is_valid = false; // whether it has attention mask == 1. - - // Update end position for causal. - int end = valid_end; - if (causal) { - const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; - if (end_causal < end) { - end = end_causal; - } - } - - is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_exp(0.f); - if (is_valid) { - thread_data_exp = expf(input_data - max_block); - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end); - - // Store value of 1.0/sum. - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. - if (threadIdx.x < all_sequence_length) { - output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f); - } -} - -// Note about the attention_mask_strides and attention_mask/key_padding_mask -// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed -// as [batch_index, sequence_index, token_index]. 
-template -__global__ void SoftmaxWithRawMaskSmallKernel( - const int all_sequence_length, - const int sequence_length, - const int3 attention_mask_strides, - const int* attention_mask, // 2D, 3D or 4D attention mask - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool skip_softmax, - const float mask_filter_value) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; - - // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data = -ROCMRT_INF_F; - if (threadIdx.x < all_sequence_length) { - thread_data = float(input[index]) * rsqrt_head_size; - - const int sequence_index = blockIdx.x % sequence_length; - if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. - if (threadIdx.x > from_index) { - thread_data = -ROCMRT_INF_F; - } - } - - const int batch_index = blockIdx.y; - int mask_offset = attention_mask_strides.x * batch_index + - attention_mask_strides.y * sequence_index + - attention_mask_strides.z * threadIdx.x; - - if (nullptr == key_padding_mask) { - const int& mask = attention_mask[mask_offset]; - if (mask == 0) - thread_data += mask_filter_value; - } else { - const bool mask = key_padding_mask[mask_offset]; - if (mask) { - thread_data = -ROCMRT_INF_F; - } - } - - if (attn_bias != nullptr) { - thread_data += float(attn_bias[index]); - } - } - - if (skip_softmax) { - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data); - } - return; - } - - const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data_exp = threadIdx.x < all_sequence_length ? 
expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); - - // Store value of 1.0/sum - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data_exp * sum_reverse_block); - } -} - -template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const T* attn_bias, const T* input, T* output, bool causal) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, - attn_bias, input, output, causal); -} - -template -__global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) { - Softmax(all_sequence_length, all_sequence_length, 0, attn_bias, input, output); -} - -template -Status ComputeSoftmax( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const T* attn_bias, const T* input, T* output, bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (!causal) { - const int blockSize = 1024; - SoftmaxKernel<<>>( - all_sequence_length, attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output, - bool causal) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
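// An empty valid window [start, end) would leave the block-reduced max at -infinity and the
// exp-sum at zero, so the row would not form a valid probability distribution; widening the
// window to the whole sequence reproduces the CPU behaviour of treating "attend to nothing"
// as "attend to everything". Host-side sketch of that normalization (illustration only):
inline void NormalizeAttentionWindow(int& start, int& end, int all_sequence_length) {
  if (start >= end) {
    start = 0;
    end = all_sequence_length;
  }
}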
- if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, - attn_bias, input, output, causal); -} - -template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. - if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - Softmax(all_sequence_length, end_position, start_position, attn_bias, input, output); -} - -template -Status ComputeSoftmaxWithMask1D( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const int* mask_start, - const T* attn_bias, const T* input, T* output, const bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - MaskedSoftmaxKernelSmall<<>>( \ - all_sequence_length, sequence_length, mask_index, mask_start, \ - attn_bias, input, output, causal); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else if (!causal) { - const int blockSize = 1024; - MaskedSoftmaxKernel<<>>( - all_sequence_length, mask_index, mask_start, - attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - return HIP_CALL(hipPeekAtLastError()); -} - -template -Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int num_heads, - const int3 attention_mask_strides, - const int* attention_mask, - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool use_persistent_softmax, - T* persistent_softmax_workspace, - const float mask_filter_value) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - T* out = use_persistent_softmax ? 
persistent_softmax_workspace : output; - auto stream = static_cast(ort_stream->GetHandle()); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - SoftmaxWithRawMaskSmallKernel<<>>( \ - all_sequence_length, sequence_length, attention_mask_strides, \ - attention_mask, key_padding_mask, attn_bias, input, out, \ - causal, rsqrt_head_size, \ - use_persistent_softmax, mask_filter_value); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - if (use_persistent_softmax) { - return dispatch_warpwise_softmax_forward(ort_stream, - output, - persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, - batch_size * num_heads * sequence_length); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh deleted file mode 100644 index 213940f132963..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace blas = onnxruntime::rocm::tunable::blas; - -namespace { -std::tuple GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) { - int m = attention->sequence_length; - int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v - int k = attention->input_hidden_size; - int batch = attention->batch_size; - return {m, n, k, batch}; -} -} // namespace - -template -struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention); - return MakeString("M", m, "_N", n, "_K", k, "_B", batch); - } - - hipblasHandle_t handle; - const AttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - const T* input_buffer; - const T* weight_buffer; - const T* bias_buffer; - T* out_buffer; - - int3 bias_strides; - - const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides - T* workspace_buffer; -}; - -template -struct GemmPermuteGenericPipeline { - inline static size_t GetOutputNumBytes(const AttentionParameters* attn) { - auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn); - return sizeof(T) * m * n * batch; - } - - inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { - return GetOutputNumBytes(attn); - } - - inline static std::tuple GetGemmMNK(const GemmPermuteParams* params) { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention); - return {batch * m, n, k}; - } - - inline static std::tuple UnspliceOutputQKV(const GemmPermuteParams* params) { - auto* attn = params->attention; - int64_t batch = attn->batch_size * attn->num_heads; - int64_t num_elems_per_batch = attn->sequence_length * attn->head_size; - int64_t num_elems = batch * num_elems_per_batch; - auto q = params->out_buffer + 0 * num_elems; - auto k = params->out_buffer + 1 * num_elems; - auto v = params->out_buffer + 2 * num_elems; - return {q, k, v}; - } - - inline static Status BroadcastBias(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N). - // TODO: use custom kernel of expand to improve the performance. - return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, 1, - /*alpha=*/1.0f, - params->ones, 1, - params->bias_buffer, n, - /*beta=*/0.0f, - params->workspace_buffer, n); - } - - inline static Status Gemm(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // result(M, N) = input x weights + bias. 
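// The BroadcastBias step above fills workspace_buffer with ones(M, 1) x bias(1, N), i.e. every
// row of the M x N workspace holds the length-N bias, so the GEMM below can pass beta = 1.0
// and accumulate input(M, K) x weights(K, N) on top of it. Host-side reference of that
// broadcast (illustration only):
#include <vector>

inline std::vector<float> BroadcastBiasRows(const std::vector<float>& bias, int m) {
  const int n = static_cast<int>(bias.size());
  std::vector<float> rows(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) rows[static_cast<size_t>(i) * n + j] = bias[j];
  return rows;
}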
- return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, k, - /*alpha=*/1.0f, - params->input_buffer, k, - params->weight_buffer, n, - /*beta=*/1.0f, - params->workspace_buffer, n); - } - - inline static Status Permute0213(const GemmPermuteParams* params) { - auto* attn = params->attention; - // input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH - return LaunchTransQkv( - params->StreamHandle(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer); - } - - static Status Run(const GemmPermuteParams* params) { - ORT_RETURN_IF_ERROR(BroadcastBias(params)); - ORT_RETURN_IF_ERROR(Gemm(params)); - ORT_RETURN_IF_ERROR(Permute0213(params)); - return Status::OK(); - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh deleted file mode 100644 index be8508670e4b1..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecialization; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - -static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; - -template -using device_batched_gemm_softmax_gemm_permute_instances = - std::tuple< - // clang-format off - // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| 
B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| - // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| - // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| - // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#if ROCM_VERSION >= 50500 - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#endif - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, 
PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> - // clang-format on - >; - -struct PreSoftmaxAttentionScoreOp { - PreSoftmaxAttentionScoreOp(float scale) : scale_(scale) {} - - // non-biased, non-masked - __host__ __device__ void operator()(float& y, const float& x) const { - y = scale_ * x; - } - - // biased or converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias) const { - y = scale_ * x + ck::type_convert(bias); - } - - // biased and converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias, const 
F16& converted_mask) const { - y = scale_ * x + ck::type_convert(bias) + ck::type_convert(converted_mask); - } - - float scale_; -}; - -// Use this function to gat implementation -template -std::vector, - PassThrough, PassThrough, D0Op, PassThrough, PassThrough, - MaskingSpec>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances() { - return {}; -} - -// implemented in impl_{fp16,bf16}[_biased][_masked].cu -// fp16, non-biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu deleted file mode 100644 index 2e32a6594d164..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
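// Pattern shared by this file and the other impl_*.cu units (simplified sketch; the names below
// are illustrative, not the real API): impl.cuh defines a generic
// GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<...>() that returns an empty list and declares
// one explicit specialization per (bias tuple, masking specialization) combination, and each .cu
// file defines one of those specializations so the heavy composable-kernel instantiations are
// split across translation units.
#include <memory>
#include <vector>

struct DeviceOp { virtual ~DeviceOp() = default; };

template <bool Biased, bool Causal>
std::vector<std::unique_ptr<DeviceOp>> GetInstances() { return {}; }  // generic: nothing registered

template <>
std::vector<std::unique_ptr<DeviceOp>> GetInstances<false, false>() {  // one such definition per .cu
  std::vector<std::unique_ptr<DeviceOp>> instances;
  // the real files populate this via add_device_operation_instances(instances, <instance tuple>{})
  return instances;
}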
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using NonBiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu deleted file mode 100644 index 91da8d9e1f9a8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu deleted file mode 100644 index b08123be18977..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh deleted file mode 100644 index 226b89cfb2b86..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,915 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/* About Computing in these Pipelines - -B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs -S: sequence length -T: total sequence length -N: num of heads -H: head dimension - -The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example: - -BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op - -In QKV projection (prior to this pipeline): - /-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H] -X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - \-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - -pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1 -pre_softmax_attn_scores_masked = pre_softmax_attn_scores * scale +? bias +? mask Scale Add Bias, +? 
is optional -attn_scores = softmax(pre_softmax_attn_scores_masked) = [B,N,S,T] Softmax -scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H] Batched GEMM2 - -Op outputs scaled_multi_head_attn: -[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H] - - -For the computing of pre_softmax_attn_scores +? mask +? bias: - -GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it! - -CK in GemmSoftmaxGemmPermuteTunablePipeline - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_2d ---> [B,T] ---> [B,1,1,T] -/ - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_3d --> [B,S,T] --> [B,1,S,T] -/ - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_4d -> [B,1,M,M] -> [B,1,S,T] -/ M is max_sequence_length from megatron, we will create a - **sub-view** from original mask buffer - -For CK implementation, there will be four cases combined: -non-biased, non-masked, no special processing. - biased, non-masked, no special processing, add the mask directly. -non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with scaled Q*K'. - biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with bias and scaled Q*K'. - -Broadcast add is not actually perform the broadcasting, just broadcast the load operation from memory. The impl details -are in composable kernels. The scale and add logic is performed via Acc0ElementOp - -# Classified modes: - -| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc -| ---- | ---- | ---- | ------ | ----- | --------- | -------- | --------- -| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format -| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format -| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format -| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic -| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true -| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false -| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false -| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false -| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true -| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true -| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false -| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false -| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true -| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true -| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed -| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed - -Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel - -About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer - -- Marked with `*` indicate the Tensor is used for k_buffer passing. -- Marked with `^` indicate the Tensor is used for v_buffer passing. 
- -# Supported Op - -- A: Attention -- MHA: MultiHeadAttention - -# Dim Value - -- B: batch_size -- N: num_heads -- H: head_size - -- S: sequence_length -- L: kv_sequence_length -- P: past_sequence_length -- T: total_sequence_length = P + L -- M: max_sequence_length -*/ - -#include "core/framework/tensor_shape.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/attention_softmax.h" -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif // USE_COMPOSABLE_KERNEL - -#include -#include - -namespace blas = onnxruntime::rocm::tunable::blas; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -inline int3 Get2DMaskStrides(int total_sequence_length) { - // stride == 0 indicate broadcasting - return {total_sequence_length, 0, 1}; -} - -// A stride maps from natural coordinate to physical offset of underlying memory storage buffer offset. We need to -// specify both of the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and memory order, say BNSH or -// BSNH, to determain the strides. To obtain the offset, we just do the inner product of coordinate with the strides. -// This wrapper create the stride vector from the physical dimension (or physical shape). -struct Strides { - // Create the strides for BNSH physically indexed memory buffer - static Strides BNSHMemory(int batch_dim, - int num_head_dim, - int seqlen_dim, - int head_size_dim) { - ORT_UNUSED_PARAMETER(batch_dim); - return Strides{LongLong4{ - static_cast(num_head_dim) * seqlen_dim * head_size_dim, - static_cast(seqlen_dim) * head_size_dim, - static_cast(head_size_dim), - static_cast(1), - }}; - } - - // Create the strides for BSNH physically indexed memory buffer - static Strides BSNHMemory(int batch_dim, - int seqlen_dim, - int num_head_dim, - int head_size_dim) { - ORT_UNUSED_PARAMETER(batch_dim); - return Strides{LongLong4{ - static_cast(seqlen_dim) * num_head_dim * head_size_dim, - static_cast(head_size_dim), - static_cast(num_head_dim) * head_size_dim, - static_cast(1), - }}; - } - - template - T ForBNSHCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.z), - static_cast(strides_for_bnsh_coord.w)}; - } - - template - T ForBSNHCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.z), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.w)}; - } - - template - T ForBNHSCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.w), - static_cast(strides_for_bnsh_coord.z)}; - } - - int64_t OffsetAt(int b, int n, int s, int h) const { - return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n + - strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h; - } - - // store intermediate strides in the canonical (b,n,s,h) coordinate order - LongLong4 strides_for_bnsh_coord; -}; - 
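// Worked example for the Strides helper above: the physical offset of a natural (b, n, s, h)
// coordinate is the inner product of the coordinate with the stride vector. For a BNSH-laid-out
// buffer with N = 12, S = 128, H = 64 the strides are {12*128*64, 128*64, 64, 1}; ForBSNHCoord()
// merely reorders the same strides so a (b, s, n, h) coordinate addresses the same memory.
inline long long BnshOffset(int b, int n, int s, int h, int N, int S, int H) {
  return static_cast<long long>(b) * N * S * H +
         static_cast<long long>(n) * S * H +
         static_cast<long long>(s) * H + h;
}
// e.g. BnshOffset(1, 3, 5, 7, /*N=*/12, /*S=*/128, /*H=*/64) == 123207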
-template -std::tuple ConvertToOffsetedBufferViews( - const RocmAttentionParameters* attn, - const T* query = nullptr, // q or packed_qkv - const T* key = nullptr, // k or packed kv - const T* value = nullptr, // - const T* present = nullptr, // present or present_k - const T* present_v = nullptr) { - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: { - return {reinterpret_cast(query), - reinterpret_cast(key), - reinterpret_cast(value)}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->max_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present_v)}; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { - auto packed_kv = reinterpret_cast(key); - return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; - } - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: { - auto packed_qkv = reinterpret_cast(query); - return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size}; - } - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetQkvStrides(const RocmAttentionParameters* attn) { - // G0 not used, because it is the slowest dimension - const int& B = attn->batch_size; - const int& N = attn->num_heads; - const int& S = attn->sequence_length; - const int& L = attn->kv_sequence_length; - const int& T = attn->total_sequence_length; - const int& M = attn->max_sequence_length; - const int& H = attn->head_size; - const int& Hv = attn->v_head_size; - - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return { - Strides::BNSHMemory(B, N, S, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - } else if (attn->qkv_format == Q_K_V_BSNH) { - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, T, H), - Strides::BNSHMemory(B, N, T, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, M, H), - Strides::BNSHMemory(B, N, M, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - 
Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, 2 * H), - Strides::BSNHMemory(B, L, N, 2 * Hv), - }; - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * Hv), - }; - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetRawMaskBufferAddrSizesAndStrides( - const int* buffer, const RocmAttentionParameters* attn) { - const int* offseted_buffer{buffer}; // how to view the mask buffer - int3 sizes{0, 0, 0}; // the logical shape of the view - int3 strides{-1, -1, -1}; // the physical memory layout - switch (attn->mask_type) { - case MASK_NONE: - case MASK_2D_DUMMY: - break; // No mask - case MASK_2D_KEY_PADDING: - sizes = {attn->batch_size, 1, attn->total_sequence_length}; - strides = Get2DMaskStrides(attn->total_sequence_length); - break; - case MASK_3D_ATTENTION: - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1}; - break; - case MASK_4D_MEGATRON: - // offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index] - offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length; - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1}; - break; - default: - LOGS_DEFAULT(FATAL) << "unsupported mask type: " << attn->mask_type; - throw std::runtime_error("unsupported mask type"); - } - return {offseted_buffer, sizes, strides}; -} - -template -struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - return MakeString( - "B", attention->batch_size, - "_S", attention->sequence_length, - "_T", attention->total_sequence_length, - "_N", attention->num_heads, - "_H", attention->head_size, - "_Hv", attention->v_head_size, - bias_buffer != nullptr ? "_B" : "_NB", - "_M", mask_index_dims.size(), - "_QKV", attention->qkv_format, - "_MODE", attention->mode); - } - - std::tuple GetGemmsMNKOBatch() const { - ORT_ENFORCE(attention != nullptr); - auto m = attention->sequence_length; - auto n = attention->total_sequence_length; - auto k = attention->head_size; - auto o = attention->v_head_size; - auto batch = attention->batch_size * attention->num_heads; - return {m, n, k, o, batch}; - } - - hipblasHandle_t handle; - const RocmAttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - float scale; - const T* q_buffer; - const T* k_buffer; - const T* v_buffer; - T* out_buffer; - - // optional, attention bias [B,N,S,T] - // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. 
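// As laid out in the pipeline comment above, the optional bias and the converted 0/1 mask
// (1 = attend) are both just additive terms on the scaled Q*K' score before softmax; a masked
// position receives mask_filter_value, typically a large negative number. Reference for a
// single score element (illustration only):
inline float PreSoftmaxScore(float qk, float scale, float bias, int mask, float mask_filter_value) {
  return qk * scale + bias + (mask == 1 ? 0.0f : mask_filter_value);
}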
- const T* bias_buffer{nullptr}; - - // optional, mask value - const int* mask_index_buffer{nullptr}; - TensorShapeVector mask_index_dims{}; - - // optional, internal - void* workspace_buffer{nullptr}; -}; - -inline bool IsKVBNMH(AttentionMode mode) { - switch (mode) { - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return true; - default: - return false; - } -} - -template -struct GemmSoftmaxGemmPermuteGenericPipeline { - static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { - return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2; - } - - static std::tuple GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams* params) { - auto bytes = GetAttentionScratchSize( - sizeof(T), - params->attention->batch_size, - params->attention->num_heads, - params->attention->sequence_length, - params->attention->total_sequence_length); - auto gemm1_out = reinterpret_cast(params->workspace_buffer); - auto softmax_out = gemm1_out + (bytes / sizeof(T)); - auto gemm2_out = softmax_out + (bytes / sizeof(T)); - return {gemm1_out, softmax_out, gemm2_out}; - } - - inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { - return GetAttentionWorkspaceSize( - sizeof(T), - attn->batch_size, - attn->num_heads, - attn->head_size, - attn->sequence_length, - attn->total_sequence_length); - } - - inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - - int k_buffer_stride = n * k; - if (IsKVBNMH(params->attention->mode)) { - k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size; - } - - // GEMM1 [m,k] * [n,k]' -> [m,n] - return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::Trans, - m, n, k, - // For raw attention mask, the scalar is moved to softmax computation. - /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, - params->q_buffer, k, m * k, - params->k_buffer, k, k_buffer_stride, - /*beta=*/0.0f, - gemm1_out, n, m * n, - batch); - } - - inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - // Softmax on [m,n] along the n dimension. - // Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length. - auto attn = params->attention; - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected. - return ComputeSoftmaxWithRawMask( - params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out, - attn->is_unidirectional, /* FIXME: this must not be attn.scale! 
*/ params->scale, - use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value); - } - - inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams* params) { - auto mask_1d = params->mask_index_buffer; - auto mask_1d_size = params->mask_index_dims[0]; - // Softmax on [m,n] along the n dimension. - // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr; - return ComputeSoftmaxWithMask1D( - params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); - } - - inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams* params) { - // Softmax on [m,n] along the n dimension. - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - return ComputeSoftmax( - params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); - } - - inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams* params) { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - - int v_buffer_stride = n * o; - if (IsKVBNMH(params->attention->mode)) { - v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size; - } - - // GEMM2 [m,n] * [n,o] -> [m,o] - // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H. 
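// Each of the B*N GEMM batches below multiplies the softmax output [S, T] by the corresponding
// value matrix [T, Hv] to produce [S, Hv]. Naive single-batch reference (illustration only):
inline void NaiveAttentionGemm2(const float* probs, const float* v, float* out, int S, int T, int Hv) {
  for (int s = 0; s < S; ++s)
    for (int h = 0; h < Hv; ++h) {
      float acc = 0.0f;
      for (int t = 0; t < T; ++t) acc += probs[s * T + t] * v[t * Hv + h];
      out[s * Hv + h] = acc;
    }
}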
- return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, o, n, - /*alpha=*/1.0f, - softmax_out, n, m * n, - params->v_buffer, o, v_buffer_stride, - /*beta=*/0.0f, - gemm2_out, o, m * o, - batch); - } - - inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams* params) { - // Permute 0213 - // gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - return LaunchTransCtx( - params->StreamHandle(), - attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); - } - - static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams* params) { - const auto& attn = params->attention; - // TODO: address the BNMH k,v strides - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return Status::OK(); - } else { - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ", - attn->qkv_format); - } - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH"); - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - // If sequence_length is 1, query of B1NH can be simply viewed as BN1H. 
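// Why the sequence_length == 1 case below is safe: with S == 1 a BSNH tensor [B, 1, N, H] and a
// BNSH tensor [B, N, 1, H] occupy byte-identical memory, so the BSNH query can be consumed as if
// it were BNSH without a transpose. Tiny offset check (illustration only):
inline bool SameOffsetWhenSeqLenIsOne(int b, int n, int h, int N, int H) {
  const int bsnh = ((b * 1 + 0) * N + n) * H + h;  // [B, 1, N, H] with s == 0
  const int bnsh = ((b * N + n) * 1 + 0) * H + h;  // [B, N, 1, H] with s == 0
  return bsnh == bnsh;  // always true
}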
- if (attn->sequence_length == 1) { - return Status::OK(); - } else { - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ", - "only if sequence_length is 1, query of BSNH can be viewed as BNSH"); - } - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH"); - default: - return TUNABLE_OP_UNSUPPORTED("unknonw"); - } - return TUNABLE_OP_UNSUPPORTED("unknonw case"); - } - - static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - auto supported_status = GetSupportedStatus(params); - if (!supported_status.IsOK()) { - return supported_status; - } - ORT_RETURN_IF_ERROR(Gemm1(params)); - - if (UseRawAttentionMask(params)) { - ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax)); - } else if (params->mask_index_dims.size() == 1) { // 1d index mask - ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params)); - } else { - ORT_RETURN_IF_ERROR(SoftmaxNoMask(params)); - } - - ORT_RETURN_IF_ERROR(Gemm2(params)); - ORT_RETURN_IF_ERROR(Permute0213(params)); - return Status::OK(); - } -}; - -template -class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp> { - public: - GemmSoftmaxGemmPermuteTunableOp(); - - inline static bool IsSupportedMode(const RocmAttentionParameters* attn) { - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: - // depends on qkv format - if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) { - return true; - } else { - return false; - } - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - return true; - default: - return false; - } - } - - inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) { - switch (attn->mask_type) { - case MASK_NONE: - case MASK_2D_DUMMY: - case MASK_2D_KEY_PADDING: - case MASK_3D_ATTENTION: - case MASK_4D_MEGATRON: - return true; - default: - return false; - } - } - - inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { - size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(attn); - -#ifdef USE_COMPOSABLE_KERNEL - if (IsSupportedMaskType(attn)) { - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn); - num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z); - } -#endif - - return num_bytes; - } - - template - __global__ static void ConvertToFilledMaskValue( - T* __restrict__ out, - const int3 out_strides, - const int* __restrict__ mask_buffer, - const int3 mask_lengths, // [B,S,T] - const int3 mask_strides, - Converter cvt) { - const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; - if (global_idx >= mask_lengths.x * mask_lengths.y * CeilDiv(mask_lengths.z, VecSize)) { - return; - } - - const int tidx = (global_idx % CeilDiv(mask_lengths.z, VecSize)) * VecSize; - const int bs_idx = global_idx / CeilDiv(mask_lengths.z, VecSize); - const int sidx = bs_idx % mask_lengths.y; - const int 
bidx = bs_idx / mask_lengths.y; - - int64_t in_offset = mask_strides.x * bidx + mask_strides.y * sidx + mask_strides.z * tidx; - int64_t out_offset = out_strides.x * bidx + out_strides.y * sidx + out_strides.z * tidx; - - if (tidx + VecSize <= mask_lengths.z) { - using LoadT = const aligned_vector; - using StoreT = aligned_vector; - LoadT load = *reinterpret_cast(mask_buffer + in_offset); - StoreT store; - -#pragma unroll - for (int i = 0; i < VecSize; i++) { - store.val[i] = cvt(load.val[i]); - } - *reinterpret_cast(out + out_offset) = store; - } else { -#pragma unroll - for (int i = 0; i < mask_lengths.z - tidx; i++) { - out[out_offset + i] = cvt(mask_buffer[in_offset + i]); - } - } - } - - static Status LaunchConvertToFilledMaskValue(const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kThreadPerBlock = 256; - constexpr const int kVecSize = 4; - - auto attn = params->attention; - auto [buffer, lengths, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - int64_t total_threads = lengths.x * lengths.y * CeilDiv(lengths.z, kVecSize); - auto num_blocks = CeilDiv(total_threads, kThreadPerBlock); - - auto mask_filter_value = attn->mask_filter_value; - auto cvt = [=] __device__(int v) -> T { - return v == 1 ? 0 : mask_filter_value; - }; - - ConvertToFilledMaskValue<<StreamHandle()>>>( - reinterpret_cast(params->workspace_buffer), {lengths.y * lengths.z, lengths.z, 1}, // out desc - buffer, lengths, strides, // mask desc - cvt); - - return HIP_CALL(hipGetLastError()); - } -}; - -#ifdef USE_COMPOSABLE_KERNEL - -template -auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); - - using Nop = ck::tensor_operation::element_wise::PassThrough; - using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } - - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } - - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector 
k_buffer_strides = ks.template ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } - - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); -} - -template -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using D0DataType = typename ck::detail::tuple_concat< - std::conditional_t, ck::Tuple<>>, - std::conditional_t, ck::Tuple<>>>::type; - - constexpr static auto MaskingSpecMaskDisabled = - ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; - constexpr static auto MaskingSpecMaskOutUpperTriangle = - ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; - - std::vector>>> - ret; - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = impl->MakeInvokerPointer(); - auto op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GemmSoftmaxGemmPermuteParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - - return GetArgAndRunInvoker(impl, invoker, params); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); - } - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = 
impl->MakeInvokerPointer(); - auto op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GemmSoftmaxGemmPermuteParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->sequence_length != params->attention->total_sequence_length, - "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); - - return GetArgAndRunInvoker(impl, invoker, params); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); - } - - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -template -GemmSoftmaxGemmPermuteTunableOp::GemmSoftmaxGemmPermuteTunableOp() { - this->RegisterOp([](const GemmSoftmaxGemmPermuteParams* params) { - return GemmSoftmaxGemmPermuteGenericPipeline::Run(params, false); - }); - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } -#endif -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h deleted file mode 100644 index 0aff519d20e99..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
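// Scalar sketch of the ConvertToFilledMaskValue conversion removed above (the HIP kernel does
// the same per-element work, just vectorized by 4): a 0/1 raw attention mask becomes an
// additive bias with 0 for kept positions and mask_filter_value (a large negative number such
// as -10000) for masked ones. Names here are illustrative only.
#include <cstddef>
#include <vector>

void MaskToAdditiveBias(const std::vector<int>& mask,  // flattened [B,S,T], values 0 or 1
                        float mask_filter_value,
                        std::vector<float>& bias) {
  bias.resize(mask.size());
  for (std::size_t i = 0; i < mask.size(); ++i)
    bias[i] = (mask[i] == 1) ? 0.0f : mask_filter_value;
}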
- -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - hipblasHandle_t& hipblas, // hipblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h deleted file mode 100644 index 768295767835a..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchElementwiseKernel(RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output); - -// The following is LaunchElementwiseKernel implementation detail. Their interfaces are exposed for kernel explorer. 
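// Host-side sketch (illustrative, not the removed HIP kernels) of the contract of
// LaunchElementwiseKernel declared above: output[i] = Fn(input[i] + bias[i % bias_length]),
// where an empty bias means no bias is added.
#include <cstddef>
#include <vector>

template <typename Fn>
void ElementwiseReference(const std::vector<float>& input, const std::vector<float>& bias,
                          std::vector<float>& output, Fn f) {
  output.resize(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    const float x = bias.empty() ? input[i] : input[i] + bias[i % bias.size()];
    output[i] = f(x);
  }
}
// Usage sketch with a GeLU-like functor:
//   ElementwiseReference(x, b, y, [](float v) { return 0.5f * v * (1.0f + std::erf(v * 0.70710678118f)); });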
-namespace internal { - -template -struct ElementwiseParams : OpParams { - ElementwiseParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, - const T* input, const T* bias, T* output, int input_length, int bias_length) - : OpParams(tuning_ctx, stream), - input(input), - bias(bias), - output(output), - input_length(input_length), - bias_length(bias_length) {} - - std::string Signature() const override { - std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length); - return sig; - } - - const T* input; - const T* bias; - T* output; - int input_length; - int bias_length; -}; - -template -class ElementwiseOp { - public: - Status operator()(const ElementwiseParams* params); - Status IsSupported(const ElementwiseParams* params); -}; - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params); - -template -class ElementwiseTunableOp : public TunableOp> { - public: - ElementwiseTunableOp(); -}; - -} // namespace internal - -#define ELEMENTWISE_FWD_DECL(FnName, T) \ - namespace functor { \ - struct FnName; \ - } - -ELEMENTWISE_FWD_DECL(FastGeLU, float); -ELEMENTWISE_FWD_DECL(FastGeLU, double); -ELEMENTWISE_FWD_DECL(FastGeLU, half); -ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(GeLU, float); -ELEMENTWISE_FWD_DECL(GeLU, double); -ELEMENTWISE_FWD_DECL(GeLU, half); -ELEMENTWISE_FWD_DECL(GeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(ReLU, float); -ELEMENTWISE_FWD_DECL(ReLU, half); -ELEMENTWISE_FWD_DECL(ReLU, BFloat16); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh deleted file mode 100644 index 8255e70d27e48..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/tunable/util.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "contrib_ops/rocm/bert/elementwise.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace functor { - -struct FastGeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - constexpr const float b = 0.7978845608028654f; // sqrt(2.0/M_PI) - - // const T cdf = a + a * _Tanh(in * (c * in * in + b)); - const T xb = x * T(b); - const T u = xb * T(0.044715f) * x * x + xb; - const T emu = __expf(-u - u); - const T cdf = T(1.0f) / (T(1.0f) + emu); - y = x * cdf; - } -}; - -struct GeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = T(0.5f) * x * (T(1.f) + T(erf(0.70710678118f * float(x)))); - } -}; - -struct ReLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = x >= T{} ? x : T{}; - } -}; - -} // namespace functor - -using onnxruntime::rocm::CeilDiv; -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -__global__ void ElementwiseKernel( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* __restrict__ output) { - const int idx = blockIdx.x * TPB + threadIdx.x; - Fn f{}; - - if (idx < input_length) { - const T x = input[idx] + (bias == nullptr ? 
T{} : bias[idx % bias_length]); - f(output[idx], x); - } -} - -template -__global__ void ElementwiseKernelVec( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* output) { - using VecT = onnxruntime::rocm::aligned_vector; - Fn f{}; - - const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP; - if (idx < input_length) { - T input_v[ILP]; - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); - T output_v[ILP]; - VecT* output_val = reinterpret_cast(&output_v); - T bias_v[ILP]; - if (bias != nullptr) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[idx % bias_length]); - } - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const T x = (bias == nullptr) ? input_v[i] : (T)(input_v[i] + bias_v[i]); - f(output_v[i], x); - } - *(reinterpret_cast(&output[idx])) = *output_val; - } -} - -template -Status LaunchElementwiseKernel( - RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output) { - internal::ElementwiseParams params(tuning_ctx, stream, input, bias, output, input_length, bias_length); - if (tuning_ctx->IsTunableOpEnabled()) { - static internal::ElementwiseTunableOp op; - return op(¶ms); - } - - return internal::ElementwiseStaticSelection(¶ms); -} - -namespace internal { - -template -Status ElementwiseOp::operator()(const ElementwiseParams* params) { - dim3 blocks(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)); - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, - params->bias, params->bias_length, - params->output); - return HIP_CALL(hipGetLastError()); -} - -template -Status ElementwiseOp::IsSupported(const ElementwiseParams* params) { - // TODO(anyone): Add tail handling for FastGelu - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || - (params->bias_length == 0 && params->input_length % VecSize == 0))); - // Avoid redundant configurations - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); - - return Status::OK(); -} - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params) { - constexpr int block_size = 256; - if constexpr (std::is_same_v) { - if (params->bias != nullptr) { - if (0 == (params->bias_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } else { - if (0 == (params->input_length % 
8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - return HIP_CALL(hipGetLastError()); -} - -template -ElementwiseTunableOp::ElementwiseTunableOp() { - this->RegisterOp(ElementwiseStaticSelection); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); -} - -#undef ADD_OP - -} // namespace internal - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ - namespace onnxruntime { \ - namespace contrib { \ - namespace rocm { \ - template Status LaunchElementwiseKernel( \ - RocmTuningContext * tuning_ctx, Stream* stream, \ - const T* input, int input_length, \ - const T* bias, int bias_length, \ - T* output); \ - namespace internal { \ - template class ElementwiseTunableOp; \ - } \ - } \ - } \ - } diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu 
b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu
deleted file mode 100644
index c2a670ea76aca..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
-
-ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float);
-ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double);
-ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half);
-ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16);
diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu
deleted file mode 100644
index 97f0f74640c6e..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
-
-ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double);
-ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float);
-ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half);
-ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16);
diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu
deleted file mode 100644
index 67e50869133f5..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
-
-ELEMENTWISE_KERNEL_IMPL(functor::ReLU, float);
-ELEMENTWISE_KERNEL_IMPL(functor::ReLU, half);
-ELEMENTWISE_KERNEL_IMPL(functor::ReLU, BFloat16);
diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc
deleted file mode 100644
index fdb62d3a2aec5..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
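// Small standalone check (illustrative) that the FastGeLU functor instantiated above,
// y = x * sigmoid(2u) with u = sqrt(2/pi) * (x + 0.044715 * x^3), is the usual tanh
// approximation 0.5 * x * (1 + tanh(u)) written with one exp instead of a tanh.
#include <cmath>
#include <cstdio>

int main() {
  const float b = 0.7978845608028654f;  // sqrt(2/pi)
  for (float x = -4.0f; x <= 4.0f; x += 0.5f) {
    const float u = b * (x + 0.044715f * x * x * x);
    const float sigmoid_form = x / (1.0f + std::exp(-2.0f * u));  // form used by the removed functor
    const float tanh_form = 0.5f * x * (1.0f + std::tanh(u));     // textbook FastGelu
    std::printf("x=% .2f  diff=%.3g\n", x, sigmoid_form - tanh_form);
  }
  return 0;
}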
- -#include "contrib_ops/rocm/bert/gemm_fast_gelu.h" - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" -#include "core/providers/cpu/math/matmul_helper.h" -#include "core/providers/rocm/rocm_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GemmFastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - GemmFastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -template -Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { - typedef typename ToHipType::MappedType HipT; - - const auto* X = ctx->Input(0); - const auto* W = ctx->Input(1); - const auto* bias = ctx->Input(2); - - bool transa = false; - bool transb = false; - bool trans_batch_a = false; - bool trans_batch_b = false; - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false)); - - Tensor* Y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (Y->Shape().Size() == 0) - return Status::OK(); - - // gemmfastgelu only support alpha == 1 and beta == 0 - const HipT alpha = ToHipType::FromFloat(1.0f); - const HipT beta = ToHipType::FromFloat(0.0f); - - using onnxruntime::rocm::tunable::blas::BlasOp; - - return blas::row_major::GemmFastGelu( - GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), - transa ? BlasOp::Trans : BlasOp::NonTrans, - transb ? BlasOp::Trans : BlasOp::NonTrans, - helper.M(), helper.N(), helper.K(), - alpha, - reinterpret_cast(X->Data()), helper.Lda(transa), - reinterpret_cast(W->Data()), helper.Ldb(transb), - (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, - beta, - reinterpret_cast(Y->MutableData()), helper.Ldc()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h deleted file mode 100644 index ae4f84fa5f033..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::RocmKernel; - -template -class GemmFastGelu final : public RocmKernel { - public: - GemmFastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh deleted file mode 100644 index 77f53f9eed027..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKBlasOpAdaptor; -using onnxruntime::rocm::CKDataTypeAdaptor; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -using FastGelu = ck::tensor_operation::element_wise::FastGelu; - -template -auto GetCKGemmAddFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple, Row, - CKDataType, CKDataType, ck::Tuple, CKDataType, - Nop, Nop, AddFastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); - - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); - - auto nop = Nop{}; - auto addfastgelu = AddFastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, std::array{0}, params->ldc, - nop, nop, addfastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} - -template -auto GetCKGemmFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple<>, Row, - CKDataType, CKDataType, ck::Tuple<>, CKDataType, - Nop, Nop, FastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("nobias ", 
impl->GetTypeString()); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); - - auto nop = Nop{}; - auto fastgelu = FastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, - {}, - params->c, - params->m, params->n, params->k, - params->lda, params->ldb, - {}, - params->ldc, - nop, nop, fastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} -#else -struct Row {}; -struct Col {}; -#endif // USE_COMPOSABLE_KERNEL - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h deleted file mode 100644 index 2b8a21b83f177..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; -using onnxruntime::rocm::tunable::blas::BlasOpToString; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -template -struct GemmFastGeluParams : OpParams { - std::string Signature() const override { - bool has_bias = (nullptr != bias) ? 0 : 1; - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); - } - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - const T* bias; - T beta; - T* c; - int64_t ldc; -}; - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu deleted file mode 100644 index 8d7e64b1015be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
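// Hypothetical sketch of how the Signature() strings used by the tunable ops above are
// assembled: the transpose ops, problem shape, and a has-bias flag are concatenated so
// tuning results found for one problem are cached and reused under the same key.
#include <cstdint>
#include <string>

std::string MakeGemmFastGeluSignature(char opa, char opb, int64_t m, int64_t n, int64_t k,
                                      bool has_bias) {
  return std::string(1, opa) + opb + "_" + std::to_string(m) + "_" + std::to_string(n) +
         "_" + std::to_string(k) + "_" + (has_bias ? "1" : "0");
}
// e.g. MakeGemmFastGeluSignature('N', 'T', 128, 3072, 768, true) -> "NT_128_3072_768_1"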
- -#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" -#include "core/providers/rocm/shared_inc/fpgeneric.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -namespace row_major { - -template -inline GEMMFASTGELU(T, ScalarT) { - GemmFastGeluParams params; - params.tuning_ctx = tuning_ctx; - params.stream = stream; - params.handle = handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - if constexpr (!std::is_same_v && std::is_same_v) { - params.alpha = ToHipType::FromFloat(std::forward(alpha)); - } else { - params.alpha = alpha; - } - params.a = a; - params.lda = lda; - params.b = b; - params.ldb = ldb; - params.bias = bias; - if constexpr (!std::is_same_v && std::is_same_v) { - params.beta = ToHipType::FromFloat(std::forward(beta)); - } else { - params.beta = beta; - } - params.c = c; - params.ldc = ldc; - - if (tuning_ctx->IsTunableOpEnabled()) { - if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } - } - - return internal::GemmFastGeluUnfused(¶ms); -} - -#define CALL_GEMMFASTGELU(T, ScalarT) \ - GemmFastGelu(tuning_ctx, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, a, lda, b, ldb, bias, \ - beta, c, ldc) - -// clang-format off -GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } -GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } -GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } -GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } -GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMMFASTGELU - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h deleted file mode 100644 index b707c63ef44be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
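// Minimal sketch (hypothetical types, not the removed template code) of the dispatch pattern
// in row_major::GemmFastGelu above: one function-local static tunable-op instance is kept per
// (opA, opB) transpose combination, so tuning state found on the first call with a given
// layout is reused by every later call with that layout.
enum class BlasOp { NonTrans, Trans };
struct Params { int m, n, k; };

struct TunableGemmFastGelu {  // stand-in for a per-layout tunable op
  int calls = 0;
  int operator()(const Params&) { return ++calls; }
};

int Dispatch(BlasOp opa, BlasOp opb, const Params& p) {
  if (opa == BlasOp::NonTrans && opb == BlasOp::NonTrans) { static TunableGemmFastGelu op; return op(p); }
  if (opa == BlasOp::Trans    && opb == BlasOp::NonTrans) { static TunableGemmFastGelu op; return op(p); }
  if (opa == BlasOp::NonTrans && opb == BlasOp::Trans)    { static TunableGemmFastGelu op; return op(p); }
  static TunableGemmFastGelu op;  // Trans / Trans
  return op(p);
}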
- -#pragma once - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/common/status.h" -#include "core/common/float16.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -#define GEMMFASTGELU(T, ScalarT) \ - common::Status GemmFastGelu( \ - RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ - const T* bias, ScalarT beta, T* c, std::int64_t ldc) - -namespace row_major { - -GEMMFASTGELU(float, float); -GEMMFASTGELU(half, half); -GEMMFASTGELU(BFloat16, BFloat16); -GEMMFASTGELU(half, float); -GEMMFASTGELU(BFloat16, float); - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#undef GEMMFASTGELU -#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh deleted file mode 100644 index e157aa57f8c43..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/gemm_hipblaslt.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -using namespace onnxruntime::rocm::tunable::blas::internal; - -template -Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { - namespace column_major = onnxruntime::rocm::tunable::blas::column_major; - ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle, - params->opb, params->opa, - params->n, params->m, params->k, - params->alpha, params->b, params->ldb, params->a, params->lda, - params->beta, params->c, params->ldc)); - - int64_t fast_gelu_input_length = params->m * params->n; - int64_t bias_length = (params->bias != nullptr) ? params->n : 0; - - // Because of GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, FastGeluOp in this combination is - // an inplace computation. - // 1. If we call GemmFastGeluUnfused directly with enabled tuning, it may cause the input buffer of FastGelu been - // updated accumulatedly and result in incorrect result finally. This only happens if the tuning's FindFastest is invoked. - // 2. It's safe to call GemmFastGeluUnfused with disabled tuning, FastGelu only run once and produce correct result. - // 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with enable tuning, GemmTunableOp and - // FastGeluTunableOp will do tune in first warmup step separately during GemmFastGeluUnfused profiling process. - // After that, the call to GemmFastGeluUnfused not invoke tuning's FindFastest of FastGelu. - // - // Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp - // to protect original input value. 
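// Illustrative sketch of the operand swap used by GemmFastGeluUnfused above: a row-major
// C[m,n] = A[m,k] * B[k,n] is obtained from a column-major GEMM by computing C^T = B^T * A^T,
// i.e. calling the column-major routine with (n, m, k) and the operands in (B, A) order.
// ColMajorGemm below is a plain reference stand-in for the BLAS call.
void ColMajorGemm(int M, int N, int K, const float* A, int lda, const float* B, int ldb,
                  float* C, int ldc) {
  // Column-major: C(i,j) = sum_p A(i,p) * B(p,j).
  for (int j = 0; j < N; ++j)
    for (int i = 0; i < M; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < K; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

void RowMajorGemm(int m, int n, int k, const float* A, const float* B, float* C) {
  // A row-major [m,k] buffer is exactly a column-major [k,m] matrix (A^T), and likewise for
  // B and C, so swapping the operand order and (m, n) yields C^T column-major == C row-major.
  ColMajorGemm(/*M=*/n, /*N=*/m, /*K=*/k, /*A=*/B, /*lda=*/n, /*B=*/A, /*ldb=*/k, /*C=*/C, /*ldc=*/n);
}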
- return onnxruntime::contrib::rocm::LaunchElementwiseKernel( - params->tuning_ctx, params->Stream(), - params->c, static_cast(fast_gelu_input_length), - params->bias, static_cast(bias_length), - params->c); -} - -template -class GemmFastGeluTunableOp : public TunableOp> { - public: - GemmFastGeluTunableOp() { - this->RegisterOp(GemmFastGeluUnfused); -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu deleted file mode 100644 index 09a6550549614..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/cpu/bert/group_query_attention_helper.h" -#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" - -#ifdef USE_COMPOSABLE_KERNEL_CK_TILE -#include "ck_tile/core/numeric/integer.hpp" -#include "fmha_fwd.hpp" -#endif - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupQueryAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ - .MayInplace(3, 1) \ - .MayInplace(4, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 6), \ - GroupQueryAttention); - -// REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -// REGISTER_KERNEL_TYPED(BFloat16) - -template -std::string GetCkFmhaDataTypeString(); - -template <> -std::string GetCkFmhaDataTypeString() { - return "fp16"; -} - -template <> -std::string GetCkFmhaDataTypeString() { - return "bf16"; -} - -__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = seqlens[idx] + inc; - } -} - -Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); - return HIP_CALL(hipGetLastError()); -} - -__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = idx * length_per_seq; - } - if (idx == 0) { - out[num_elems] = num_elems * length_per_seq; - } -} - -Status 
LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqstart_init_kernel<<>>(out, num_elems, length_per_seq); - return HIP_CALL(hipGetLastError()); -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, - const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - int b = tid / seqlen; - int s = tid % seqlen; - if (b < batch_size) { - if (s < seqlens_k[b] + 1) { - position_ids[tid] = s; - } else { - position_ids[tid] = 1; - } - } -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < batch_size) { - position_ids[tid] = seqlens_k[tid]; - } -} - -// Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { - const int seqlen = parameters.sequence_length; - const int batch_size = parameters.batch_size; - const int threads = max_threads_per_block; - const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_first_prompt) { - SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); - } else { - SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); - } - return HIP_CALL(hipGetLastError()); -} - -template -GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) - : RocmKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - is_past_bsnh_ = false; - is_unidirectional_ = true; - local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); - do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; - scale_ = info.GetAttrOrDefault("scale", 0.0f); -} - -template <> -std::once_flag GroupQueryAttention::arch_checking_{}; - -template <> -std::once_flag GroupQueryAttention::arch_checking_{}; - -template -Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { -#if USE_COMPOSABLE_KERNEL_CK_TILE - auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); - const Tensor* query = ctx->Input(0); - const Tensor* key = ctx->Input(1); - const Tensor* value = ctx->Input(2); - const Tensor* past_key = ctx->Input(3); - const Tensor* past_value = ctx->Input(4); - const Tensor* seqlens_k = ctx->Input(5); - const Tensor* total_seqlen = ctx->Input(6); - const Tensor* cos_cache = ctx->Input(7); - const Tensor* sin_cache = ctx->Input(8); - - auto& device_prop = GetDeviceProp(); - std::call_once( - arch_checking_, - [](const hipDeviceProp_t& device_prop) { - if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && - std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { - LOGS_DEFAULT(WARNING) - << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " - << "CDNA2 and CDNA3 archs."; - 
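// CPU sketch (illustrative) of the SeqlensToPosIds kernels above. For the first prompt,
// tokens inside the valid prefix get positions 0..seqlens_k[b], and padded positions get 1;
// for token generation, each batch entry's single new token gets position seqlens_k[b]
// (the past length).
#include <cstdint>
#include <vector>

void SeqlensToPosIdsReference(const std::vector<int32_t>& seqlens_k, int sequence_length,
                              bool is_first_prompt, std::vector<int64_t>& position_ids) {
  const int batch_size = static_cast<int>(seqlens_k.size());
  if (is_first_prompt) {
    position_ids.resize(static_cast<std::size_t>(batch_size) * sequence_length);
    for (int b = 0; b < batch_size; ++b)
      for (int s = 0; s < sequence_length; ++s)
        position_ids[static_cast<std::size_t>(b) * sequence_length + s] =
            (s < seqlens_k[b] + 1) ? s : 1;
  } else {
    position_ids.resize(batch_size);
    for (int b = 0; b < batch_size; ++b) position_ids[b] = seqlens_k[b];
  }
}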
LOGS_DEFAULT(WARNING) - << "GroupQueryAttention running on an unsuppoted GPU may result in " - << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; - } - }, - device_prop); - - GroupQueryAttentionParameters parameters; - using HipT = typename ToHipType::MappedType; - - const int max_thr_per_blk = device_prop.maxThreadsPerBlock; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, - key, - value, - past_key, - past_value, - cos_cache, - sin_cache, - ¶meters, - num_heads_, - kv_num_heads_, - seqlens_k, - total_seqlen, - is_past_bsnh_, - scale_, - max_thr_per_blk)); - - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int kv_num_heads = parameters.kv_num_heads; - const int head_size = parameters.head_size; - AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - parameters.local_window_size = local_window_size_; - parameters.is_unidirectional = is_unidirectional_; - // parameters.zeros_count = kZerosCount; - // parameters.zero_ptr = zeros_.get(); - // parameters.left_padding = left_padding_; - parameters.do_rotary = do_rotary_; - parameters.rotary_interleaved = rotary_interleaved_; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(batch_size); - output_shape[1] = static_cast(sequence_length); - output_shape[2] = static_cast(parameters.hidden_size); - Tensor* output = ctx->Output(0, output_shape); - Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - - int4 past_shape; - std::vector present_dims; - Strides present_strides; - Strides past_strides; - if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - past_shape = { - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; - past_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); - present_dims = { - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; - present_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - } else { // BNSH - past_shape = { - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; - past_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); - present_dims = { - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; - present_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); - } - TensorShape present_shape(present_dims); - Tensor* present_key = ctx->Output(1, present_shape); - Tensor* present_value = ctx->Output(2, present_shape); - - Strides query_strides; - Strides key_strides; - Strides value_strides; - int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord - const HipT* query_ptr = reinterpret_cast(query->DataRaw()); - const HipT* key_ptr; - const HipT* value_ptr; - if (!parameters.is_packed_qkv) { - query_strides 
= Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); - value_strides = key_strides; - key_ptr = reinterpret_cast(key->DataRaw()); - value_ptr = reinterpret_cast(value->DataRaw()); - } else { - query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - value_strides = query_strides; - const size_t key_offset = static_cast(num_heads * head_size); - const size_t value_offset = static_cast(kv_num_heads * head_size); - key_ptr = query_ptr + key_offset; - value_ptr = key_ptr + value_offset; - } - - IAllocatorUniquePtr rotary_q_tmp; - IAllocatorUniquePtr rotary_k_tmp; - if (parameters.do_rotary) { - size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); - size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); - auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); - - rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); - rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); - auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, - reinterpret_cast(seqlens_k->DataRaw()), - reinterpret_cast(rotary_position_ids_tmp.get()), - hip_stream, max_thr_per_blk)); - // Launch rotary embedding kernel - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, - reinterpret_cast(rotary_position_ids_tmp.get()), - reinterpret_cast(cos_cache->DataRaw()), - reinterpret_cast(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - query_strides.ForBNSHCoord(), - rotary_q_strides.ForBNSHCoord())); - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, - reinterpret_cast(rotary_position_ids_tmp.get()), - reinterpret_cast(cos_cache->DataRaw()), - reinterpret_cast(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.kv_num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - key_strides.ForBNSHCoord(), - rotary_k_strides.ForBNSHCoord())); - query_ptr = reinterpret_cast(rotary_q_tmp.get()); - key_ptr = reinterpret_cast(rotary_k_tmp.get()); - query_strides = rotary_q_strides; - key_strides = rotary_k_strides; - } - - const int* seqlens_k_ptr = seqlens_k ? 
reinterpret_cast(seqlens_k->DataRaw()) : nullptr; - IAllocatorUniquePtr seqlens_k_tmp; - - // build present kv cache - auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); - auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - if (parameters.is_first_prompt) { - // copy prompt kv to present kv - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); - const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); - parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: - if (!parameters.kv_share_buffer) { - // copy past to present, - // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are - // not the same, aka, can not be as simple as strided - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - // In the case of share buffer - ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); - ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); - } - // then append new kv to present - size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - - // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. - // we should call fmha with total sequence lengths - seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); - seqlens_k_ptr = seqlens_k_tmp.get(); - } - static_assert(std::is_same_v); - - const float scale = parameters.scale == 0.0f - ? 1.f / sqrt(static_cast(parameters.head_size)) - : parameters.scale; - bias_enum bias_type = bias_enum::no_bias; - - mask_info mask = [&]() { - if (local_window_size_ != -1) { - mask_info ret; - ret.type = mask_enum::window_generic; - ret.left = local_window_size_; - ret.right = parameters.is_unidirectional ? 
0 : -1; - // ret.x = kv_sequence_length - (sequence_length - ret.left); - // ret.y = sequence_length + (ret.right - kv_sequence_length); - return ret; - } - - if (parameters.is_first_prompt && is_unidirectional_) { - return mask_info::decode("t", sequence_length, kv_sequence_length); - } - - return mask_info::decode("0", sequence_length, kv_sequence_length); - }(); - - auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_q_tmp.get(), batch_size, - query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_k_tmp.get(), batch_size, - present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); - - fmha_fwd_args args{ - query_ptr, - present_key->DataRaw(), - present_value->DataRaw(), - nullptr, // bias, alibi/element - nullptr, // lse, logsumexp buffer - output->MutableDataRaw(), - seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode - seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode - seqlens_k_ptr, // seqlen_k_ptr, for group mode - sequence_length, // seqlen_q, for batch mode - kv_sequence_length, // seqlen_k, for batch mode - parameters.batch_size, // batch - parameters.sequence_length, // max_seqlen_q - parameters.head_size, // hdim_q - parameters.head_size, // hdim_v - parameters.num_heads, - parameters.kv_num_heads, - scale, - 1.0f, // scale_p of squant, useless - 1.0f, // scale_o of squant, useless - static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S - static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S - static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S - batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 - static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S - static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N - static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N - static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N - 0, // nhead_stride_bias - batch_size, // nhead_stride_lse - static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B - static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B - static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B - static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B - 0, // batch_stride_bias - num_heads * batch_size, // batch_stride_lse - static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B - mask.left, // window_size_left - mask.right, // window_size_right - static_cast(mask.type)}; - -#if 0 - std::cout - << "\n sequence_length:" << sequence_length - << "\n kv_sequence_length:" << kv_sequence_length - << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache - << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; - - std::cout - << "\n 
q_ptr:" << args.q_ptr - << "\n k_ptr:" << args.k_ptr - << "\n v_ptr:" << args.v_ptr - << "\n bias_ptr:" << args.bias_ptr - << "\n lse_ptr:" << args.lse_ptr - << "\n o_ptr:" << args.o_ptr - << "\n seqstart_q_ptr:" << args.seqstart_q_ptr - << "\n seqstart_k_ptr:" << args.seqstart_k_ptr - << "\n seqlen_k_ptr:" << args.seqlen_k_ptr - << "\n seqlen_q:" << args.seqlen_q - << "\n seqlen_k:" << args.seqlen_k - << "\n batch:" << args.batch - << "\n max_seqlen_q:" << args.max_seqlen_q - << "\n hdim_q:" << args.hdim_q - << "\n hdim_v:" << args.hdim_v - << "\n nhead_q:" << args.nhead_q - << "\n nhead_k:" << args.nhead_k - << "\n scale_s:" << args.scale_s - << "\n scale_p:" << args.scale_p - << "\n scale_o:" << args.scale_o - << "\n stride_q:" << args.stride_q - << "\n stride_k:" << args.stride_k - << "\n stride_v:" << args.stride_v - << "\n stride_bias:" << args.stride_bias - << "\n stride_o:" << args.stride_o - << "\n nhead_stride_q:" << args.nhead_stride_q - << "\n nhead_stride_k:" << args.nhead_stride_k - << "\n nhead_stride_v:" << args.nhead_stride_v - << "\n nhead_stride_bias:" << args.nhead_stride_bias - << "\n nhead_stride_lse:" << args.nhead_stride_lse - << "\n nhead_stride_o:" << args.nhead_stride_o - << "\n batch_stride_q:" << args.batch_stride_q - << "\n batch_stride_k:" << args.batch_stride_k - << "\n batch_stride_v:" << args.batch_stride_v - << "\n batch_stride_bias:" << args.batch_stride_bias - << "\n batch_stride_lse:" << args.batch_stride_lse - << "\n batch_stride_o:" << args.batch_stride_o - << "\n window_size_left:" << args.window_size_left - << "\n window_size_right:" << args.window_size_right - << "\n mask_type:" << args.mask_type - << std::endl; -#endif - - fmha_fwd_traits traits{ - parameters.head_size, - parameters.head_size, // v head size - GetCkFmhaDataTypeString(), - !parameters.is_first_prompt, // true, // is_group_mode - true, // is_v_rowmajor ? dim is fastest : seq is fastest - mask.type, - bias_type, - false, // has_lse - false, // do_fp8_static_quant, aka, squant - }; - - ck_tile::stream_config stream_config{ - hip_stream, - false // time_kernel - }; - - auto duration = fmha_fwd(traits, args, stream_config); - if (duration < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); - } - HIP_RETURN_IF_ERROR(hipGetLastError()); - - return Status::OK(); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); -#endif -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h deleted file mode 100644 index ce0de1f761aa5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class GroupQueryAttention final : public RocmKernel { - public: - GroupQueryAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // different for k and v for group query attention - int local_window_size_; - bool is_unidirectional_; - bool is_past_bsnh_; - bool do_rotary_; - bool rotary_interleaved_; - float scale_; - - private: - static std::once_flag arch_checking_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh deleted file mode 100644 index 2eeb7c3e8f279..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh +++ /dev/null @@ -1,270 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on bert plugins in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
-#pragma once - -#include -#include -#include -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -__device__ inline T Rsqrt(const T& x); - -template <> -__device__ inline float Rsqrt(const float& x) { - return rsqrtf(x); -} - -template <> -__device__ inline half Rsqrt(const half& x) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return hrsqrt(x); -#else - return half(rsqrtf(static_cast(x))); -#endif -} - -__device__ inline half2 AddHalf2(const half2 a, const half2 b) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return __hadd2(a, b); -#else - return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); -#endif -} - -struct KeyValuePairSum { - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - return hipcub::KeyValuePair(a.key + b.key, a.value + b.value); - } - - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - const half2 a2 = __halves2half2(a.key, a.value); - const half2 b2 = __halves2half2(b.key, b.value); - const half2 res = AddHalf2(a2, b2); - return hipcub::KeyValuePair(__low2half(res), __high2half(res)); - } - - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - return hipcub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); - } -}; - -template -__device__ inline void LayerNorm( - const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast(output[idx]); - const U g = static_cast(gamma[i]); - const U b = (nullptr == beta) ? U(0.f) : static_cast(beta[i]); - output[idx] = static_cast(g * (val - mu) * rsigma + b); - } -} - -template -__device__ inline void SimplifiedLayerNorm( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. 
- - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast(output[idx]); - const U g = static_cast(gamma[i]); - output[idx] = static_cast(g * val * rsigma); - } -} - -template -__device__ inline void SimplifiedLayerNormVec( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV gamma_v = *reinterpret_cast(gamma + i); - VecV output_v = *reinterpret_cast(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } - } -} - -template -__device__ inline void LayerNormVec( - const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + i) : VecV(); - const VecV gamma_v = *reinterpret_cast(gamma + i); - VecV output_v = *reinterpret_cast(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } - } -} - -template -__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair& thread_data, - const int ld, const int idx, const V* beta, const V* gamma, - const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const hipcub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV beta_v = (beta != nullptr) ? 
*reinterpret_cast(beta + threadIdx.x * ILP) : VecV(); - const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = (beta != nullptr) ? U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } -} - -template -__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu deleted file mode 100644 index 5d4ef53b8ba97..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
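The layer_norm.cuh helpers deleted above block-reduce the pair (x / ld, x^2 / ld) and recover mean and inverse standard deviation from those two moments: mu = E[x], rsigma = rsqrt(E[x^2] - mu^2 + epsilon). A small CPU reference of the same arithmetic, for orientation only (the function name and use of std::vector are ours):

#include <cmath>
#include <vector>

// Reference LayerNorm over one row of length ld, using the same two-moment
// formulation as the deleted device helpers. beta may be empty (no bias).
std::vector<float> LayerNormReference(const std::vector<float>& x,
                                      const std::vector<float>& gamma,
                                      const std::vector<float>& beta,
                                      float epsilon) {
  const size_t ld = x.size();
  float mean = 0.f, mean_sq = 0.f;
  for (float v : x) {
    mean += v / ld;         // accumulates E[x]
    mean_sq += v * v / ld;  // accumulates E[x^2]
  }
  const float rsigma = 1.f / std::sqrt(mean_sq - mean * mean + epsilon);
  std::vector<float> y(ld);
  for (size_t i = 0; i < ld; ++i) {
    const float b = beta.empty() ? 0.f : beta[i];
    y[i] = gamma[i] * (x[i] - mean) * rsigma + b;
  }
  return y;
}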
- -#include "contrib_ops/rocm/bert/multihead_attention.h" - -#include "contrib_ops/cpu/bert/multihead_attention_helper.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "core/platform/env_var_utils.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_MHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MultiHeadAttention) - -REGISTER_MHA_KERNEL_TYPED(float); -REGISTER_MHA_KERNEL_TYPED(MLFloat16); - -static constexpr int kPastSequenceLengthInputIndex = 7; -static constexpr int kBeamWidthInputIndex = 8; -static constexpr int kPastInputIndex = 5; -static constexpr int kPresentOutputIndex = 1; - -#define REGISTER_DMMHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - DecoderMaskedMultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ - .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ - MultiHeadAttention) - -REGISTER_DMMHA_KERNEL_TYPED(float); -REGISTER_DMMHA_KERNEL_TYPED(MLFloat16); - -template -MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) - : RocmKernel(info), - attn_type_(info.node().OpType() == "DecoderMaskedMultiHeadAttention" ? 
kDecoderMaskedMultiHeadAttention - : kMultiHeadAttention) { - int64_t num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - num_heads_ = static_cast(num_heads); - - mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - - scale_ = info.GetAttrOrDefault("scale", 0.0f); - - past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - tunable_op_ = std::make_shared(); -} - -template -Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { - ORT_ENFORCE( - GetTuningContext()->IsTunableOpEnabled(), - "MultiHeadAttention of ROCm EP is only supported if tunable op is used and tuning is enabled."); - - const Tensor* query = context->Input(0); - const Tensor* key = context->Input(1); - const Tensor* value = context->Input(2); - - const Tensor* bias{}; - const Tensor* key_padding_mask{}; - const Tensor* attention_bias{}; - const Tensor* past_key{}; - const Tensor* past_value{}; - const Tensor* past_seq_len{}; - - const Tensor* cache_indirection = nullptr; - - if (attn_type_ == kMultiHeadAttention) { - bias = context->Input(3); - key_padding_mask = context->Input(4); - attention_bias = context->Input(5); - past_key = context->Input(6); - past_value = context->Input(7); - } else if (attn_type_ == kDecoderMaskedMultiHeadAttention) { - key_padding_mask = context->Input(3); - attention_bias = context->Input(4); - past_key = context->Input(5); - past_value = context->Input(6); - past_seq_len = context->Input(kPastSequenceLengthInputIndex); - // const Tensor* beam_width = context->Input(8); // NOTE: not used - // const Tensor* cache_indirection = context->Input(9); // TODO: should not present for ROCm EP - bias = context->Input(10); - } - - if (nullptr != bias) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qkv_bias is not supported on ROCm EP. " - "User should fuse the qkv bias to qkv projection instead."); - } - - auto& device_prop = GetDeviceProp(); - RocmAttentionParameters attn; - ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, - key, - value, - bias, - key_padding_mask, - attention_bias, - past_key, - past_value, - cache_indirection, - past_seq_len, - &attn, /* parameters */ - num_heads_, - mask_filter_value_, - scale_, - is_unidirectional_, - past_present_share_buffer_, - attn_type_, - device_prop.maxThreadsPerBlock)); - - if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(attn.batch_size); - output_shape[1] = static_cast(attn.sequence_length); - output_shape[2] = static_cast(attn.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector present_dims{ - attn.batch_size, - attn.num_heads, - past_present_share_buffer_ ? 
attn.max_sequence_length : attn.total_sequence_length, - attn.head_size, - }; - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - - ORT_RETURN_IF_ERROR(ClassifyAttentionMode( - attn_type_, &attn, - /*qkv=*/{query, key, value}, - /*past=*/{past_key, past_value}, - /*present=*/{present_key, present_value})); - - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); - auto workspace = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - hipStream_t stream = Stream(context); - if (nullptr != present_key) { // process past present concat - Strides dst_strides; - - int4 past_shape; - Strides past_src_strides; - const HipT* past_key_src; - const HipT* past_value_src; - HipT* past_key_dst{}; - HipT* past_value_dst{}; - - int4 add_shape; - Strides add_src_strides; - const HipT* add_key_src = reinterpret_cast(key->DataRaw()); - const HipT* add_value_src = reinterpret_cast(value->DataRaw()); - HipT* add_key_dst; - HipT* add_value_dst; - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; - past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); - past_key_src = reinterpret_cast(past_key->DataRaw()); - past_value_src = reinterpret_cast(past_value->DataRaw()); - past_key_dst = reinterpret_cast(present_key->MutableDataRaw()); - past_value_dst = reinterpret_cast(present_value->MutableDataRaw()); - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if ( - attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides = 
Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "past present concatenation is not implemented for attention mode ", attn.mode); - } - add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h) - add_key_dst = reinterpret_cast(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - add_value_dst = reinterpret_cast(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - - if (past_key_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(), - past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - if (past_value_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(), - past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(), - add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(), - add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - GemmSoftmaxGemmPermuteParams params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = GetHipblasHandle(context); - params.attention = &attn; - params.device_prop = &device_prop; - params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; - std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews( - &attn, - nullptr == query ? nullptr : reinterpret_cast(query->DataRaw()), - nullptr == key ? nullptr : reinterpret_cast(key->DataRaw()), - nullptr == value ? nullptr : reinterpret_cast(value->DataRaw()), - nullptr == present_key ? nullptr : reinterpret_cast(present_key->DataRaw()), - nullptr == present_value ? nullptr : reinterpret_cast(present_value->DataRaw())); - params.out_buffer = reinterpret_cast(output->MutableDataRaw()); - - if (key_padding_mask != nullptr) { - params.mask_index_buffer = key_padding_mask->Data(); - params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); - } - - if (attention_bias != nullptr) { - params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); - } - - params.workspace_buffer = reinterpret_cast(workspace.get()); - return (*std::static_pointer_cast(tunable_op_))(¶ms); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h deleted file mode 100644 index 1d676d7a7bcac..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
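The GemmSoftmaxGemmPermute pipeline invoked by the deleted MultiHeadAttention kernel above computes, per batch and head, softmax(scale * Q K^T + attention_bias/mask) * V after the past/present KV concatenation handled by the strided copies. A single-head CPU sketch of that computation follows; it is illustrative only, not the tunable op's implementation, and omits bias and masking.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Q: [seq_q x head_size], K and V: [seq_k x head_size], all row-major.
// Returns O: [seq_q x head_size] = softmax(scale * Q K^T) V for one head.
std::vector<float> SingleHeadAttention(const std::vector<float>& Q,
                                       const std::vector<float>& K,
                                       const std::vector<float>& V,
                                       int seq_q, int seq_k, int head_size,
                                       float scale) {
  std::vector<float> O(static_cast<size_t>(seq_q) * head_size, 0.f);
  std::vector<float> scores(seq_k);
  for (int i = 0; i < seq_q; ++i) {
    float max_s = -std::numeric_limits<float>::infinity();
    for (int j = 0; j < seq_k; ++j) {
      float s = 0.f;
      for (int d = 0; d < head_size; ++d)
        s += Q[i * head_size + d] * K[j * head_size + d];
      scores[j] = scale * s;
      max_s = std::max(max_s, scores[j]);
    }
    float denom = 0.f;  // max-subtracted softmax for numerical stability
    for (int j = 0; j < seq_k; ++j) {
      scores[j] = std::exp(scores[j] - max_s);
      denom += scores[j];
    }
    for (int j = 0; j < seq_k; ++j)
      for (int d = 0; d < head_size; ++d)
        O[i * head_size + d] += (scores[j] / denom) * V[j * head_size + d];
  }
  return O;
}

As in the deleted code, scale defaults to 1 / sqrt(head_size) when the scale attribute is 0.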
- -#pragma once - -#include -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/bert/attention_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class MultiHeadAttention final : public RocmKernel { - public: - MultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType attn_type_; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; - bool past_present_share_buffer_{false}; - bool is_unidirectional_{false}; - - // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: - // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. - // 2. We don't want to construct the object repeatly (which is expansive) during Compute. - std::shared_ptr tunable_op_; -}; - -template -class DecoderMaskedMultiHeadAttention final : public RocmKernel { - public: - DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType mha_type; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc deleted file mode 100644 index 9e649fb591896..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/skip_layer_norm.h" - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipSimplifiedLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -using namespace ONNX_NAMESPACE; - -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { - ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); - ORT_ENFORCE(epsilon_ >= 0); -} - -template -Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* input = ctx->Input(0); - const Tensor* skip = ctx->Input(1); - const Tensor* gamma = ctx->Input(2); - - const Tensor* beta = Simplified ? nullptr : ctx->Input(3); - const Tensor* bias = Simplified ? 
ctx->Input(3) : ctx->Input(4); - - Tensor* output = ctx->Output(0, input->Shape()); - - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); - - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - if (input->Shape().Size() == 0) { - return Status::OK(); - } - - const auto& input_dims = input->Shape().GetDims(); - size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } - - int hidden_size = static_cast(input_dims[input_dims_size - 1]); - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } - - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } - - int64_t element_count = input->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - return LaunchSkipLayerNormKernel( - GetTuningContext(), - ctx->GetComputeStream(), - reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - epsilon_, - hidden_size, - static_cast(element_count)); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h deleted file mode 100644 index 02228bc59cedc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
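The deleted SkipLayerNorm::ComputeInternal above validates shapes and then launches a fused kernel that forms sum = input + skip (+ bias), optionally writes that sum to output 3 (skip_input_bias_add_output), and normalizes it; the Simplified variant drops mean subtraction and beta and scales by rsqrt(E[sum^2] + epsilon). A CPU sketch of the simplified path over one row, with names of our choosing:

#include <cmath>
#include <vector>

// One row of SkipSimplifiedLayerNormalization: y = gamma * s * rsqrt(mean(s^2) + eps)
// with s = input + skip + bias. sum_out mirrors the optional fourth output.
void SkipSimplifiedLayerNormRow(const std::vector<float>& input,
                                const std::vector<float>& skip,
                                const std::vector<float>& bias,  // may be empty
                                const std::vector<float>& gamma,
                                float epsilon,
                                std::vector<float>& y,
                                std::vector<float>* sum_out) {
  const size_t ld = input.size();
  std::vector<float> s(ld);
  float mean_sq = 0.f;
  for (size_t i = 0; i < ld; ++i) {
    s[i] = input[i] + skip[i] + (bias.empty() ? 0.f : bias[i]);
    mean_sq += s[i] * s[i] / ld;
  }
  if (sum_out != nullptr) *sum_out = s;
  const float rsigma = 1.f / std::sqrt(mean_sq + epsilon);
  y.resize(ld);
  for (size_t i = 0; i < ld; ++i) y[i] = gamma[i] * s[i] * rsigma;
}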
- -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class SkipLayerNorm final : public RocmKernel { - public: - SkipLayerNorm(const OpKernelInfo& op_kernel_info); - Status ComputeInternal(OpKernelContext* context) const override; - - private: - float epsilon_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu deleted file mode 100644 index 8387c49a3310b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ /dev/null @@ -1,86 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Modifications: Add SkipLayerNormKernelVec to -// leverage vectorized load/write. -// and templatize ComputeSkipLayerNorm for different -// data types. -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. 
- -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" - -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) { - // this must be true because element_count is the total size of the tensor - assert(element_count % ld == 0); - - SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, - gamma, beta, bias, epsilon, ld, element_count); - - if (tuning_ctx->IsTunableOpEnabled()) { - static SkipLayerNormTunableOp op; - return op(¶ms); - } - - return SkipLayerNormStaticSelection(¶ms); -} - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h deleted file mode 100644 index 5e2a92447d2f5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning, - Stream* stream, - V* output, // output tensor - T* skip_input_bias_add_output, // optional output tensor - const T* input, // input tensor - const T* skip, // skip tensor - const V* gamma, // Layer normalization gamma tensor - const V* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count // number of elements in input tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h deleted file mode 100644 index fcfbc8969e498..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "contrib_ops/rocm/bert/layer_norm.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -T maybe2half(float x); - -template <> -float maybe2half(float x) { - return x; -} - -template <> -half maybe2half(float x) { - return __float2half_rn(x); -} - -template -__global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, - const U epsilon, V* output, T* skip_input_bias_add_output) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = (bias == nullptr) ? static_cast(input[idx]) + static_cast(skip[idx]) : static_cast(input[idx]) + static_cast(skip[idx]) + static_cast(bias[i]); - const U rldval = reverse_ld * val; - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = static_cast(val); - } - - output[idx] = static_cast(val); - } - - if constexpr (Simplified) { - SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelVec( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - using VecT = aligned_vector; - using VecV = aligned_vector; - if (threadIdx.x * ILP < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - - const VecT input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? 
*reinterpret_cast(bias + i) : VecT(); - VecT skip_input_bias_add_output_v, output_v; - -#pragma unroll - for (int k = 0; k < ILP; k++) { - const U val = hasBias ? static_cast(input_v.val[k]) + static_cast(skip_v.val[k]) + static_cast(bias_v.val[k]) : static_cast(input_v.val[k]) + static_cast(skip_v.val[k]); - const U rldval = reverse_ld * val; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[k] = static_cast(val); - } - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - output_v.val[k] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - *(reinterpret_cast(output + idx)) = output_v; - } - } - - if constexpr (Simplified) { - SimplifiedLayerNormVec(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNormVec(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U rld = U(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld - - using VecT = aligned_vector; - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - VecT input_v; - if (ILP * threadIdx.x < ld) { - input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? *reinterpret_cast(bias + threadIdx.x * ILP) : VecT(); - VecT skip_input_bias_add_output_v; - - U rldval_sum = U(0.f); - U rldvalsq_sum = U(0.f); -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = hasBias ? static_cast(input_v.val[i]) + static_cast(skip_v.val[i]) + static_cast(bias_v.val[i]) : static_cast(input_v.val[i]) + static_cast(skip_v.val[i]); - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[i] = static_cast(val); - } - - const U rldval = rld * val; - rldval_sum += rldval; - rldvalsq_sum += rldval * val; - input_v.val[i] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); - } - - if constexpr (Simplified) { - SimplifiedLayerNormSmall(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output); - return; - } - - LayerNormSmall(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h deleted file mode 100644 index 0391704ce1c56..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::CeilDiv; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct SkipLayerNormParams : OpParams { - SkipLayerNormParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, - const T* bias, float epsilon, int ld, int element_count) - : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} - - std::string Signature() const override { - std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); - return sig; - } - - V* output; - T* skip_input_bias_add_output; - const T* input; - const T* skip; - const V* gamma; - const V* beta; - const T* bias; - float epsilon; - int ld; - int element_count; -}; - -template -Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { - // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, - // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld <= 8192 && params->ld % VecSize == 0 && - params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); - SkipLayerNormKernelSmall<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld > 0 && params->ld % VecSize == 0 && - (params->ld >= ThreadsPerBlock * VecSize || - (params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))))); - SkipLayerNormKernelVec<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { - bool hasBias = (params->bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? 
false : true; - const int grid_size = params->element_count / params->ld; - const int block_size = 256; - -#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ - if (params->ld <= ELEMENTS) { \ - SkipLayerNormKernelSmall<<StreamHandle()>>>( \ - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ - hasBias, hasSkipInputBiasAdditionOutput); \ - break; \ - } - if (0 == (params->ld % 4)) { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 32, 2) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 32, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 96, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } else { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 64, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } - return HIP_CALL(hipPeekAtLastError()); -} // namespace rocm - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) - -template -class SkipLayerNormTunableOp : public TunableOp> { - public: - SkipLayerNormTunableOp() { - this->RegisterOp(SkipLayerNormStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp) - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp) - - // NOTE: the 1st kernel is SkipLayerNorm Original implementation. - this->SetDefaultId(0); - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc deleted file mode 100644 index 6ae8d1202d462..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
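The deleted SkipLayerNormTunableOp above registers the static-selection heuristic as candidate 0 plus a grid of (ThreadsPerBlock, VecSize) kernel variants, and ORT's TunableOp framework times the registered candidates per problem signature ("ld_element_count") and reuses the fastest. The sketch below shows that select-by-timing idea generically; it is not the actual TunableOp API.

#include <chrono>
#include <cstddef>
#include <functional>
#include <limits>
#include <string>
#include <vector>

// A candidate pairs a name with a callable that runs (and synchronizes) one
// instance of the problem; SelectFastest times each and returns the winner.
struct Candidate {
  std::string name;
  std::function<void()> run;
};

std::size_t SelectFastest(const std::vector<Candidate>& candidates,
                          int warmup = 2, int iters = 10) {
  std::size_t best = 0;
  double best_ms = std::numeric_limits<double>::infinity();
  for (std::size_t i = 0; i < candidates.size(); ++i) {
    for (int w = 0; w < warmup; ++w) candidates[i].run();
    const auto t0 = std::chrono::high_resolution_clock::now();
    for (int it = 0; it < iters; ++it) candidates[i].run();
    const auto t1 = std::chrono::high_resolution_clock::now();
    const double ms =
        std::chrono::duration<double, std::milli>(t1 - t0).count() / iters;
    if (ms < best_ms) { best_ms = ms; best = i; }
  }
  return best;  // cached per problem signature in the real framework
}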
-#include -#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -// The environment variable is for testing purpose only, and it might be removed in the future. -// If you need some option in production, please file a feature request. -constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS"; - -// Initialize the singleton instance -TransformerOptions TransformerOptions::instance; - -const TransformerOptions* TransformerOptions::GetInstance() { - if (!instance.initialized_) { - // We do not use critical section here since it is fine to initialize multiple times by different threads. - int value = ParseEnvironmentVariableWithDefault(kTransformerOptions, 0); - instance.Initialize(value); - - if (value > 0) - std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode() - << ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() - << ",DisableHalf2=" << instance.DisableHalf2() - << std::endl; - } - - return &instance; -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h deleted file mode 100644 index 6816b5b9d07ec..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -class TransformerOptions { - public: - static const TransformerOptions* GetInstance(); - - bool IsPrecisionMode() const { return is_precision_mode_; } - - bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; } - - bool DisableHalf2() const { return disable_half2_; } - - void Initialize(int value) { - is_precision_mode_ = (value & 0x01) > 0; - disable_persistent_softmax_ = (value & 0x02) > 0; - disable_half2_ = (value & 0x04) > 0; - initialized_ = true; - } - - private: - // Default is false. If the mode is on, prefer precision than speed. - bool is_precision_mode_{false}; - - // Disable persistent softmax. - bool disable_persistent_softmax_{false}; - - // Disable half2 kernel. - bool disable_half2_{false}; - - bool initialized_{false}; - - static TransformerOptions instance; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh deleted file mode 100644 index d0a0d09fcbae3..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#endif // USE_COMPOSABLE_KERNEL - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKDataTypeAdaptor; - -// The SiLU function is a special case of Swish function, -// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: -// SiLU(x) = x * sigmoid(x) -// Swish(x) = x * sigmoid(bx) -// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -template -auto GetCKGroupNormNHWCTypeStringAndOps() { - using XDataType = typename CKDataTypeAdaptor::type; - using YDataType = typename CKDataTypeAdaptor::type; - using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; - using GammaDataType = float; - using BetaDataType = float; - - using Activation = std::conditional_t; - - std::vector>>> ret; - for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; - auto invoker = impl->MakeInvokerPointer(); - - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by composable kernel."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->use_silu, "Silu version only support groupnorm with silu"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->use_silu, "Pass version only support groupnorm without silu"); - } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, - params->c, params->channels_per_group, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; - std::vector reduce_dims{1, 2, 4}; - - auto activation = Activation{}; - - auto arg = impl->MakeArgumentPointer(in_lengths, // lengths - in_out_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - in_out_strides, // yStrides - {0, 0}, // saveMeanStrides - {0, 0}, // saveInvStdStrides - reduce_dims, // reduceDims - params->epsilon, - params->src, - params->gamma, - params->beta, - params->dst, - nullptr, - nullptr, - activation); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op))); - } - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace 
onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh deleted file mode 100644 index 68f7d47282845..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ /dev/null @@ -1,130 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface -using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation - -// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp - -template -using device_normalization_f32_instances = std::tuple< - // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -template -using device_normalization_f16_instances = - // clang-format off - std::tuple < - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, 
- DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -// Use this function to get implementation -template -std::vector>> -GetDeviceGroupNormInstances() { - return {}; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Pass, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Pass, 5, 3>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu deleted file mode 100644 index ad191314e5e4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu deleted file mode 100644 index ceb53ed442abc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
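The composable-kernel path above views the NHWC tensor as (N, H, W, G, C/G) and reduces over dims {1, 2, 4} before applying gamma/beta and, optionally, SiLU (x * sigmoid(x)). The following is a CPU reference sketch of that computation, useful for sanity-checking the device kernels; GroupNormNHWC here is a hypothetical helper, not the production implementation.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Reference NHWC group normalization with the same reduction layout as the deleted CK
// path: mean/variance are taken over H, W and the channels of each group, then a
// per-channel affine transform and an optional fused SiLU are applied.
void GroupNormNHWC(const std::vector<float>& x, std::vector<float>& y,
                   const std::vector<float>& gamma, const std::vector<float>& beta,
                   int n, int h, int w, int c, int groups, float eps, bool with_silu) {
  const int cpg = c / groups;  // channels per group
  const int hw = h * w;
  for (int ni = 0; ni < n; ++ni) {
    for (int g = 0; g < groups; ++g) {
      double sum = 0.0, sq_sum = 0.0;
      for (int p = 0; p < hw; ++p)
        for (int k = 0; k < cpg; ++k) {
          const double v = x[(ni * hw + p) * c + g * cpg + k];
          sum += v;
          sq_sum += v * v;
        }
      const double count = double(hw) * cpg;
      const double mean = sum / count;
      const double var = sq_sum / count - mean * mean;
      const double rstd = 1.0 / std::sqrt(var + eps);
      for (int p = 0; p < hw; ++p)
        for (int k = 0; k < cpg; ++k) {
          const int ci = g * cpg + k;
          const int idx = (ni * hw + p) * c + ci;
          const float v = float((x[idx] - mean) * rstd) * gamma[ci] + beta[ci];
          y[idx] = with_silu ? v / (1.0f + std::exp(-v)) : v;  // SiLU(v) = v * sigmoid(v)
        }
    }
  }
}

int main() {
  const int n = 1, h = 2, w = 2, c = 4, groups = 2;
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  std::vector<float> y(x.size()), gamma(c, 1.0f), beta(c, 0.0f);
  GroupNormNHWC(x, y, gamma, beta, n, h, w, c, groups, 1e-5f, /*with_silu=*/false);
  for (float v : y) std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}
```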
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h deleted file mode 100644 index 7cff640db2f34..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { - GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, - onnxruntime::Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - float* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) - : OpParams(tuning_ctx, ort_stream), - GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, - num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} - - std::string Signature() const override { - std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; - std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; - std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; - std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; - std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + - std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + - skip_suffix + broadcast_suffix + bias_suffix; - return sig; - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu deleted file mode 100644 index 142aaf14e8d2d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCM kernel is hipified from CUDA kernel. 
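GroupNormNHWCTunableParams::Signature above turns the problem shape and feature flags into a string key so the tunable-op framework can cache which kernel variant won for a given shape. A small sketch of a key built the same way; MakeSignature is a hypothetical stand-in for that member function.

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Builds a cache key shaped like the one produced by Signature() above: batch size,
// spatial size, channels and groups, followed by suffixes for the optional features.
// Identical keys let the tuner reuse a previously selected kernel.
std::string MakeSignature(int n, int h, int w, int c, int groups,
                          bool use_silu, bool has_skip, bool broadcast_skip, bool has_bias) {
  std::ostringstream oss;
  oss << n << '_' << (h * w) << '_' << c << '_' << groups
      << (use_silu ? "_silu" : "_pass")
      << (has_skip ? "_skip" : "_noskip")
      << (broadcast_skip ? "_broadcast" : "_nobroadcast")
      << (has_bias ? "_bias" : "_nobias");
  return oss.str();
}

int main() {
  std::cout << MakeSignature(2, 64, 64, 320, 32, true, false, false, true) << '\n';
  return 0;
}
```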
-#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -#include -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) { - GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, - reinterpret_cast(workspace), epsilon, batch_size, num_channels, - height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - - if (params.channels_per_block % params.channels_per_group != 0 || - params.channels_per_block > kMaxSize || - (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "GroupNorm in ROCM does not support the input: n=", batch_size, - " h=", height, - " w=", width, - " c=", num_channels, - " groups=", num_groups); - } - - HIP_RETURN_IF_ERROR(hipMemsetAsync( - params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); - - if (tuning_ctx->IsTunableOpEnabled()) { - static GroupNormNHWCTunableOp op; - return op(¶ms); - } - - return GroupNormNHWCStaticSelection(¶ms); -} - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - half* add_out, const half* input, const half* skip, const half* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - float* add_out, const float* input, const float* skip, const float* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh deleted file mode 100644 index c6ca16bfdfc80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "core/providers/rocm/triton_kernel.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_TRITON_KERNEL - -namespace { - -template -std::string GetGroupNormTritonGroupName() { - std::string ret = "GroupNormTriton_"; - std::string silu_suffix = WithSilu ? 
"Silu_" : "Pass_"; - ret += silu_suffix; - ret += GetDataTypeName(); - return ret; -} - -} // namespace - -template -auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); - auto* kernel_list = GetOrtTritonKernelByGroup(group_name); - if (kernel_list == nullptr) { - return ret; - } - - for (auto i : *kernel_list) { - // Check params match - auto* metadata = GetOrtTritonKernelMetadata(i); - auto block_size = metadata->constants.at("BLOCK_SIZE"); - auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", - params->channels_per_group, ")."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); - } - // Construct args for launch kernel - struct { - const void* src; - const void* skip; - const void* bias; - void* out; - void* add_out; - const void* gamma; - const void* beta; - int hw; - int c; - int c_per_group; - float eps; - bool has_skip; - bool has_bias; - bool broadcast_skip; - } args = { - (const void*)params->src, - (const void*)params->skip, - (const void*)params->bias, - (void*)params->dst, - (void*)params->skip_workspace, - (const void*)params->gamma, - (const void*)params->beta, - params->hw, - params->c, - params->channels_per_group, - params->epsilon, - params->skip != nullptr, - params->bias != nullptr, - params->broadcast_skip, - }; - - // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); - }; - ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); - } - return ret; -} - -#endif // USE_TRITON_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py deleted file mode 100644 index 5ba96ebc117f0..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ /dev/null @@ -1,135 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -from itertools import product - -import triton -import triton.language as tl - - -@triton.jit -def group_norm_kernel( - input_ptr, - skip_ptr, - bias_ptr, - output_ptr, - add_out_ptr, - gamma_ptr, - beta_ptr, - img_size, - c, - c_per_group, - eps, - has_skip, - has_bias, - broadcast_skip, - BLOCK_SIZE: tl.constexpr, - HW_SIZE: tl.constexpr, - ACTIVATION_SILU: tl.constexpr, -): - row_x = tl.program_id(0) - row_y = tl.program_id(1) - stride = img_size * c - input_ptr += row_x * stride + row_y * c_per_group - output_ptr += row_x * stride + row_y * c_per_group - gamma_ptr += row_y * c_per_group - beta_ptr += row_y * c_per_group - - cols = tl.arange(0, BLOCK_SIZE) - hw = tl.arange(0, HW_SIZE) - offsets = hw[:, None] * c + cols[None, :] - mask = (cols < c_per_group)[None, :] - - bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - if has_skip: - add_out_ptr += row_x * stride + row_y * c_per_group - if broadcast_skip: - broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group - bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - else: - skip_ptr += row_x * stride + row_y * c_per_group - if has_bias: - bias_ptr += row_y * c_per_group - bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - - # Calculate mean and variance - _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c - a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - if has_skip and not broadcast_skip: - s_ptr = skip_ptr + i * HW_SIZE * c - s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - a += s - if has_bias or broadcast_skip: - a += bias - _sum += a - _square_sum += a * a - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - tl.store(add_y_ptr + offsets, a, mask=mask) - - # Set axis=None (or leave it unspecified) to reduce all axes. - # TODO: In older Triton we have to reduce an axis at a time, but in our case - # for some configs it may have some issue when reducing sequentially along the axes. - group_mean = tl.sum(_sum, axis=None) / (img_size * c_per_group) - group_var = tl.sum(_square_sum, axis=None) / (img_size * c_per_group) - group_mean * group_mean - - rstd = 1 / tl.sqrt(group_var + eps) - - # Normalize and apply linear transformation - gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) - beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - y_ptr = output_ptr + i * HW_SIZE * c - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - else: - x_ptr = input_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - group_mean) * rstd - y = x_hat * gamma + beta - if ACTIVATION_SILU: - y *= tl.sigmoid(y) - tl.store(y_ptr + offsets, y, mask=mask) - - -# We can have more combinations of blocks and hw_sizes, e.g., -# blocks = [16, 32, 64, 128, 256, 512] -# hw_sizes = [8, 16, 32, 64, 128, 256, 512] -# but this will result in too many functions and slow down the compilation. 
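The kernel above computes the group statistics in a single pass, accumulating a sum and a sum of squares per tile and then using Var[x] = E[x^2] - E[x]^2. A small check of that identity against the two-pass definition, with float accumulation as in the kernel; this is an illustrative sketch, not part of the deleted module.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {0.5f, -1.25f, 3.0f, 2.0f, -0.75f, 1.5f};

  // One-pass statistics, as in group_mean / group_var above.
  float sum = 0.f, sq_sum = 0.f;
  for (float v : x) { sum += v; sq_sum += v * v; }
  const float n = float(x.size());
  const float mean = sum / n;
  const float var_one_pass = sq_sum / n - mean * mean;

  // Two-pass definition for comparison.
  float var_two_pass = 0.f;
  for (float v : x) var_two_pass += (v - mean) * (v - mean);
  var_two_pass /= n;

  const float eps = 1e-5f;
  std::printf("one-pass var=%f two-pass var=%f rstd=%f\n",
              var_one_pass, var_two_pass, 1.f / std::sqrt(var_one_pass + eps));
  return 0;
}
```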
-with_silu = [True, False] -dtypes = ["fp32", "fp16"] -blocks = [16, 32, 64, 128] -hw_sizes = [8, 16, 32, 64, 128, 256] -warps = [1, 2, 4, 8, 16] -name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" -group_pattern = "GroupNormTriton_{}_{}" - - -def get_function_table(): - func_table = [] - - for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): - silu_suffix = "Silu" if silu else "Pass" - name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) - kwargs = { - "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, - } - func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} - func_table.append(func_desc) - return func_table - - -if __name__ == "__main__": - func_table = get_function_table() - for func_desc in func_table: - print(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h deleted file mode 100644 index e6831f764b418..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_triton.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - GroupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
- switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCSumKernel - <<StreamHandle()>>>( - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); - return HIP_CALL(hipGetLastError()); -} - -template -void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - GroupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ - params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ - params->hw, params->hw_per_block, params->use_silu); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
- switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCScaleKernel - <<StreamHandle()>>>( - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, - params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, - params->use_silu); - return HIP_CALL(hipGetLastError()); -} - -template -class GroupNormNHWCOp { - public: - Status operator()(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - auto status = GroupNormNHWCSumOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - status = GroupNormNHWCScaleOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); - } - - Status IsSupported(const GroupNormNHWCTunableParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, - ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && - params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), - "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", - params->channels_per_block); - - return Status::OK(); - } -}; - -template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - GroupNormNHWCSum(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - GroupNormNHWCScale(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); -} - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) - -template -class GroupNormNHWCTunableOp : public TunableOp> { - public: - GroupNormNHWCTunableOp() { - this->RegisterOp(GroupNormNHWCStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - - for 
(auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc deleted file mode 100644 index 35427a02c631d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/nn/conv.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - NhwcConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc deleted file mode 100644 index 4f3be98d97f80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include "core/common/status.h" -#include "core/providers/rocm/nn/conv.h" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace { - -// Copied from hipDNN/library/src/hcc_detail/hipdnn_miopen.cpp -miopenStatus_t _miopenAddTensor( - miopenHandle_t handle, - const void* alpha, - const miopenTensorDescriptor_t aDesc, - const void* A, - const void* beta, - const miopenTensorDescriptor_t cDesc, - void* C, - const void* zero_scalar) { - const miopenTensorOp_t tensorOp = miopenTensorOpAdd; - // Using miopenOpTensor to implement Add operator. 
- // opnd2 = Add ( 0.0 * opnd0, alpha * opnd1 ) + beta * opnd2 - return miopenOpTensor(handle, tensorOp, - zero_scalar, cDesc, C, - alpha, aDesc, A, - beta, cDesc, C); -} - -} // namespace - -template -struct FNVHash { - uint32_t GetValue() const { return value_; } - - void Hash(const void* in_ptr, size_t nbytes) { - auto ptr = reinterpret_cast(in_ptr); - for (size_t i = 0; i < nbytes; ++i) { - value_ ^= ptr[i]; - value_ *= PRIME; - } - } - - template ::value, size_t>::type = 0> - FNVHash& operator<<(const T& pod) { - Hash(&pod, sizeof(pod)); - return *this; - } - - template - FNVHash& operator<<(const std::vector& pod_array) { - for (const auto& pod : pod_array) { - (*this) << pod; - } - return *this; - } - - void HashTensor(miopenTensorDescriptor_t tdesc) { - int size = 0; - miopenGetTensorDescriptorSize(tdesc, &size); - (*this) << size; - std::vector dims(size); - std::vector strides(size); - miopenDataType_t dtype; - miopenGetTensorDescriptor(tdesc, &dtype, dims.data(), strides.data()); - (*this) << dtype; - (*this) << dims; - (*this) << strides; - } - - void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { - int spatial_dim = 1; -#if ROCM_VERSION >= 50500 - MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); - std::vector pads{spatial_dim}; - std::vector strides{spatial_dim}; - std::vector dilations{spatial_dim}; - miopenConvolutionMode_t mode; - MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); -#else - // Previous versions of MIOpen doesn't provide API to probe the dimension of a - // miopenConvolutionDescriptor_t, so we have to guess. - // This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor, - // which fails when requestedSpatialDim > the convolution's spatial dimension - constexpr const int kMaxSpatialDim = 5; - std::vector pads{kMaxSpatialDim}; - std::vector strides{kMaxSpatialDim}; - std::vector dilations{kMaxSpatialDim}; - miopenConvolutionMode_t mode; - bool spatial_dim_guessed = false; - for (int i = 0; i < kMaxSpatialDim; i++) { - if (miopenStatusSuccess == miopenGetConvolutionNdDescriptor( - cdesc, i, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)) { - spatial_dim_guessed = true; - break; - } - } - ORT_ENFORCE(spatial_dim_guessed, "Failed to guess the actual spatial dimension"); - // Remove the extra dimension - pads.resize(spatial_dim); - strides.resize(spatial_dim); - dilations.resize(spatial_dim); -#endif - (*this) << spatial_dim; - (*this) << pads; - (*this) << strides; - (*this) << dilations; - (*this) << mode; - } - - private: - uint32_t value_ = BASIS; -}; - -template -class FusedConv : public onnxruntime::rocm::Conv { - public: - using Base = onnxruntime::rocm::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::rocm::Conv(info) { - std::string activation; - ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); - ORT_THROW_IF_ERROR(MapMode(activation)); - MIOPEN_CALL_THROW(miopenCreateActivationDescriptor(&activation_desc_)); - MIOPEN_CALL_THROW(miopenSetActivationDescriptor(activation_desc_, activation_mode_, 0.0, 0.0, 0.0)); - MIOPEN_CALL_THROW(miopenCreateOperatorArgs(&fusion_args_)); - } - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv); - - ~FusedConv() { - if (activation_desc_) { - MIOPEN_CALL_THROW(miopenDestroyActivationDescriptor(activation_desc_)); - activation_desc_ = nullptr; - } - - if (fusion_args_) { - miopenDestroyOperatorArgs(fusion_args_); - } - } - - Status 
ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(Base::s_.mutex); - - ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); - if (Base::s_.Y->Shape().Size() == 0) { - return Status::OK(); - } - - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - auto factory = [this](FusedConvFusionData& fusion) { - return this->DoCreateFusionDesc(this->Node().Name(), fusion); - }; - auto& cached_item = plan_cache_.FindOrCreateFusionPlanCache(Hash(), - factory); - bool should_try_fusion_api = cached_item.Validate(this->GetMiopenHandle(context)); - - typedef typename onnxruntime::rocm::ToHipType::MappedType HipT; - const auto alpha = onnxruntime::rocm::Consts::One; - const auto beta = onnxruntime::rocm::Consts::Zero; - IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); - miopenStatus_t fusion_status = miopenStatusNotInitialized; - - if (should_try_fusion_api) { - auto& fusion_info = *cached_item.fusion; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsConvForward(fusion_args_, - fusion_info.conv_op, - &alpha, - &beta, - Base::s_.w_data)); - if (has_z) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_z_op, - &alpha, - &beta, - Base::s_.z_data)); - } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_b_op, - &alpha, - &beta, - Base::s_.b_data)); - } - if (activation_desc_) { - const float relu_notused = 0.0; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsActivForward(fusion_args_, - fusion_info.act_op, - &alpha, - &beta, - relu_notused, - relu_notused, - relu_notused)); - } - fusion_status = miopenExecuteFusionPlan(this->GetMiopenHandle(context), - fusion_info.plan, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.y_tensor, - Base::s_.y_data, - fusion_args_); - } - if (miopenStatusSuccess != fusion_status) { - MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(this->GetMiopenHandle(context), - &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.fwd_algo, - &beta, - Base::s_.y_tensor, - Base::s_.y_data, - workspace.get(), - Base::s_.workspace_bytes)); - if (has_b) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.b_tensor, Base::s_.b_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - if (has_z) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.z_tensor, Base::s_.z_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - MIOPEN_RETURN_IF_ERROR(miopenActivationForward(this->GetMiopenHandle(context), - activation_desc_, - &alpha, - Base::s_.y_tensor, - Base::s_.y_data, - &beta, - Base::s_.y_tensor, - Base::s_.y_data)); - } - if (Base::s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(onnxruntime::rocm::SliceOutUnwantedOutputSection( - this->Stream(context), - Base::s_.y_data, - Base::s_.y_dims_with_adjusted_pads, - Base::s_.Y->MutableDataRaw(), - Base::s_.y_dims.GetDims(), - Base::s_.slice_starts, - Base::s_.slice_ends, - Base::s_.slice_axes, - Base::s_.element_size)); - } - return Status::OK(); - } - - private: - Status MapMode(const std::string& activaton_mode) { - if (activaton_mode == "Relu") { - activation_mode_ = miopenActivationMode_t::miopenActivationRELU; - } else { - return ORT_MAKE_STATUS( - StatusCategory::ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, - "unsupported conv activation mode \"", activaton_mode, "\""); - } - return Status::OK(); - } - 
miopenActivationMode_t activation_mode_; - miopenActivationDescriptor_t activation_desc_ = nullptr; - - miopenOperatorArgs_t fusion_args_ = nullptr; - - // MIOpen Fusion API - // TODO: create one fusion descriptor shared by multiple FusedConv - // objects - // - // Considerations: - // How to determine two FusedConv objects may share the same fusion - // descriptor? Hashing x_tensor,conv_desc, etc.? - struct FusedConvFusionData { - miopenFusionPlanDescriptor_t plan = nullptr; - miopenFusionOpDescriptor_t conv_op = nullptr; - miopenFusionOpDescriptor_t bias_b_op = nullptr; - miopenFusionOpDescriptor_t bias_z_op = nullptr; - miopenFusionOpDescriptor_t act_op = nullptr; - - // TODO: There is a potential problem. miopenHandle_t may be destroyed and - // re-created later, sharing the same address. Currently there is any way - // to detect it? - mutable std::unordered_set compiled_on; - - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedConvFusionData); - - FusedConvFusionData() {} - ~FusedConvFusionData() { - if (plan) { - miopenDestroyFusionPlan(plan); - } - } - }; - - struct FusionPlanCacheItem { - std::unique_ptr fusion; - Status creation_result; - // TODO: Add a timestamp for eviction - // std::chrono::time_point last_access; - - FusionPlanCacheItem() {} - - miopenStatus_t CompileOnHandle(miopenHandle_t handle) const { - if (!fusion->plan) { - return miopenStatusNotInitialized; - } - auto iter = fusion->compiled_on.find(handle); - if (iter != fusion->compiled_on.end()) { - return miopenStatusSuccess; - } - auto ret = miopenCompileFusionPlan(handle, fusion->plan); - if (miopenStatusSuccess == ret) { - fusion->compiled_on.insert(handle); - } else { - return ret; - } - return miopenStatusSuccess; - } - - bool Validate(miopenHandle_t handle) const { - if (Status::OK() != creation_result) { - return false; - } - if (!fusion || !fusion->plan) { - return false; - } - auto compiling_status = CompileOnHandle(handle); - if (miopenStatusSuccess != compiling_status) { - return false; - } - - return true; - } - }; - - struct FusionPlanCache { - mutable std::mutex mutex; - using HashKey = uint32_t; - std::unordered_map cache_directory_; - - FusionPlanCache() { - } - - FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key, - std::function factory) { - std::lock_guard lock(mutex); - auto iter = cache_directory_.find(key); - if (iter == cache_directory_.end()) { - cache_directory_[key].fusion = std::make_unique(); - cache_directory_[key].creation_result = factory(*cache_directory_[key].fusion); - if (Status::OK() != cache_directory_[key].creation_result) { - cache_directory_[key].fusion.reset(); - } - } - return cache_directory_[key]; - } - }; - - static FusionPlanCache plan_cache_; - - Status DoCreateFusionDesc(const std::string& node_name, FusedConvFusionData& fusion) const { - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - MIOPEN_RETURN_IF_ERROR(miopenCreateFusionPlan(&fusion.plan, - miopenVerticalFusion, - Base::s_.x_tensor)); - auto status = miopenCreateOpConvForward(fusion.plan, &fusion.conv_op, Base::s_.conv_desc, Base::s_.w_desc); - if (status == miopenStatusUnsupportedOp) { - auto msg = MakeString("MIOpen does not support the conv fusion for node \"", - node_name, "\", fallback to unfused implementation."); - LOGS_DEFAULT(WARNING) << msg; - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, msg); - } - MIOPEN_RETURN_IF_ERROR(status); - - if (has_z) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, - &fusion.bias_z_op, - Base::s_.z_tensor)); 
- } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, - &fusion.bias_b_op, - Base::s_.b_tensor)); - } - if (activation_desc_) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpActivationForward(fusion.plan, - &fusion.act_op, - activation_mode_)); - } - return Status::OK(); - } - - uint32_t Hash() const { - FNVHash hash; - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - hash.HashTensor(Base::s_.x_tensor); - hash.HashConvolutionDescriptor(Base::s_.conv_desc); - hash.HashTensor(Base::s_.w_desc); - if (has_z) { - hash.HashTensor(Base::s_.z_tensor); - } - if (has_b) { - hash.HashTensor(Base::s_.b_tensor); - } - if (activation_desc_) { - hash << static_cast(activation_mode_); - } - return hash.GetValue(); - } -}; - -template -typename FusedConv::FusionPlanCache FusedConv::plan_cache_; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FusedConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FusedConv); - -REGISTER_KERNEL_TYPED(float); -REGISTER_KERNEL_TYPED(MLFloat16); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu deleted file mode 100644 index 3539f32252944..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/common.h" -#include "core/common/float16.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; -using namespace onnxruntime::rocm::tunable::blas; - -class GemmFloat8 final : public RocmKernel { - public: - GemmFloat8(const OpKernelInfo& info) : RocmKernel(info) { - transA_ = info.GetAttrOrDefault("transA", 0); - transB_ = info.GetAttrOrDefault("transB", 0); - dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); - alpha_ = info.GetAttrOrDefault("alpha", 1); - beta_ = info.GetAttrOrDefault("beta", 0); - } - Status ComputeInternal(OpKernelContext* ctx) const override; - - private: -#if !defined(DISABLE_FLOAT8_TYPES) - template - Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const; - template - Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; - - template - [[nodiscard]] inline auto* GetOp() const { - using OpT = GemmFloat8TunableOp; - if (tunable_op_) { - return static_cast(tunable_op_.get()); - } - - auto create = std::make_unique(); // avoid new - tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { - auto release = std::unique_ptr(); // avoid delete - release.reset(static_cast(ptr)); - }); - - return static_cast(tunable_op_.get()); - } -#endif - - float alpha_; - float beta_; - bool transA_; - bool transB_; - int64_t dtype_; - - // fully type erased - mutable std::shared_ptr tunable_op_; -}; - -Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { -#if defined(DISABLE_FLOAT8_TYPES) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DISABLE_FLOAT8_TYPES"); -#else - const Tensor* A = ctx->Input(0); - const 
Tensor* B = ctx->Input(1); - const Tensor* C = ctx->Input(2); // bias - const Tensor* scale_a = ctx->Input(3); - const Tensor* scale_b = ctx->Input(4); - const Tensor* scale_y = ctx->Input(5); - - auto a_shape = A->Shape(); - auto b_shape = B->Shape(); - ORT_ENFORCE(a_shape.NumDimensions() == 2); - ORT_ENFORCE(b_shape.NumDimensions() == 2); - - auto m = !transA_ ? a_shape[0] : a_shape[1]; - auto k = !transA_ ? a_shape[1] : a_shape[0]; - ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible - auto n = !transB_ ? b_shape[1] : b_shape[0]; - - TensorShapeVector output_shape = {m, n}; - Tensor* Y = ctx->Output(0, output_shape); - - ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); - ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); - ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); - ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); - - if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8"); -#endif -} - -#if !defined(DISABLE_FLOAT8_TYPES) -template -Status GemmFloat8::ComputeFp8Fp16Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = alpha_; - params.scale_a_dev = static_cast(scale_a->DataRaw()); - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? 
k : n; - params.scale_b = 1.0f; // NOTE: not used - params.scale_b_dev = nullptr; // NOTE: not used - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transB is not implemented"); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} - -template -Status GemmFloat8::ComputeFp16Fp8Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = 1.0f; // NOTE: not used - params.scale_a_dev = nullptr; // NOTE: not used - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? k : n; - params.scale_b = alpha_; - params.scale_b_dev = static_cast(scale_b->DataRaw()); - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#else -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#endif - -ONNX_OPERATOR_KERNEL_EX( - GemmFloat8, - kMSDomain, - 1, - kRocmExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TR", BuildKernelDefConstraints()) - .TypeConstraint("TS", BuildKernelDefConstraints()), - GemmFloat8); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh deleted file mode 100644 index b545eb1f2a149..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
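The deleted GemmFloat8 launcher treats every operand as row-major, so the leading dimensions follow directly from the transpose flags (lda = transA ? m : k, ldb = transB ? k : n, ldc = n). A short sketch of that selection; RowMajorLeadingDims is a hypothetical helper.

```cpp
#include <cstdio>

// Leading-dimension selection for row-major storage: the leading dimension is the row
// length of the matrix as stored. A is stored (m x k), or (k x m) when transA; B is
// stored (k x n), or (n x k) when transB; the output C is always (m x n).
struct Dims { long lda, ldb, ldc; };

Dims RowMajorLeadingDims(long m, long n, long k, bool transA, bool transB) {
  Dims d;
  d.lda = transA ? m : k;
  d.ldb = transB ? k : n;
  d.ldc = n;
  return d;
}

int main() {
  Dims d = RowMajorLeadingDims(64, 128, 256, /*transA=*/false, /*transB=*/true);
  std::printf("lda=%ld ldb=%ld ldc=%ld\n", d.lda, d.ldb, d.ldc);  // 256 256 128
  return 0;
}
```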
- -#pragma once - -#include -#include -#include - -#if defined(USE_COMPOSABLE_KERNEL) - -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/utility/functional3.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#if !defined(DISABLE_FLOAT8_TYPES) -#include "core/common/float8.h" -#endif -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -constexpr bool always_false = false; - -template -struct Scale { - constexpr const static bool is_pack2_invocable = true; - constexpr const static bool is_pack4_invocable = true; - - explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} - - template - __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { - static_assert(always_false, "not implemented"); - (void)x; - } - - template <> - __forceinline__ __host__ __device__ ck::half_t fast_type_convert(ck::f8_t x) const { - // https://github.com/ROCmSoftwarePlatform/triton/blob/0cc3f8b84a16892396f6e08a04991034d67e32b1/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L220-L233 - constexpr const uint16_t mask = 0x7fff; - constexpr const uint16_t sign_mask = 0x8000; - constexpr const uint16_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x2000; - } else if constexpr (std::is_same_v) { - return 0x1c00; - } - }(); - - uint8_t x_u8 = reinterpret_cast(x); - uint16_t x_u16 = static_cast(x_u8) << 8; - uint16_t exp = (x_u16 & mask) >> 1; - uint16_t y = (x_u16 & sign_mask) | (exp + exp_compensate); - return reinterpret_cast(y); - } - - __forceinline__ __host__ __device__ void operator()(ck::half_t& y, const ck::f8_t& x) const { - float scale = scale_value_ * (*dev_scale_ptr_); - y = ck::type_convert(scale * fast_type_convert(x)); - } - - __forceinline__ __host__ __device__ void operator()(ck::half2_t& ys, const ck::f8x2_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - const uchar2& x2_u8 = reinterpret_cast(xs); - uchar4 x{0, x2_u8.x, 0, x2_u8.y}; - uint32_t x_u32 = reinterpret_cast(x); - - uint32_t exp = (x_u32 & mask) >> 1; - uint32_t v = (x_u32 & sign_mask) | (exp + exp_compensate); - ys = scale * reinterpret_cast(v); - } - - __forceinline__ __host__ __device__ void operator()(ck::half4_t& ys, const ck::f8x4_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - uint32_t xs_u32 = reinterpret_cast(xs); - uint32_t x_u32_0 = __byte_perm(xs_u32, 0, 0x1504); - uint32_t x_u32_1 = __byte_perm(xs_u32, 0, 0x3726); - 
uint32_t exp_0 = (x_u32_0 & mask) >> 1; - uint32_t exp_1 = (x_u32_1 & mask) >> 1; - uint32_t v_0 = (x_u32_0 & sign_mask) | (exp_0 + exp_compensate); - uint32_t v_1 = (x_u32_1 & sign_mask) | (exp_1 + exp_compensate); - uint64_t v = v_0 | uint64_t(v_1) << 32; - ys = scale * reinterpret_cast(v); - } - - float scale_value_; - const float* const dev_scale_ptr_; -}; -#endif - -namespace blas { - -template -struct GemmFloat8Params : tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - float scale_a{}; - const float* scale_a_dev{}; - const TA* a; - int64_t lda; - float scale_b{}; - const float* scale_b_dev{}; - const TB* b; - int64_t ldb; - TC* c; - float scale_c{}; - const float* scale_c_dev{}; - int64_t ldc; -}; - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -template -auto CreateOp(float scale, const float* dev_scale) { - if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else { - return Nop{}; - } -} - -template -auto GetCKF8SplitKGemmTypeStringAndOps() { - using CKTA = typename CKDataTypeAdaptor::type; - using CKTB = typename CKDataTypeAdaptor::type; - using CKTC = typename CKDataTypeAdaptor::type; - - using CKLayoutA = typename CKBlasOpAdaptor::type; - using CKLayoutB = typename CKBlasOpAdaptor::type; - - using OpA = std::conditional_t, Scale, Nop>; - using OpB = std::conditional_t, Scale, Nop>; - using OpC = std::conditional_t, Scale, Nop>; - - using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< - CKLayoutA, CKLayoutB, Row, - CKTA, CKTB, CKTC, - OpA, OpB, OpC>; - - std::vector>>> ret; - - for (auto num_split : {1, 4, 16, 64}) { - std::vector> instances{}; - if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances); - } else { - static_assert(always_false, "no instances for the type combination"); - LOGS_DEFAULT(FATAL) << "no instances for the type combination"; - } - for (auto&& impl : instances) { - auto type_string = std::to_string(ret.size()) + "_" + impl->GetTypeString() + "_SplitK" + std::to_string(num_split); - 
auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmFloat8Params* params) -> Status { - OpA op_a = CreateOp(params->scale_a, params->scale_a_dev); - OpB op_b = CreateOp(params->scale_b, params->scale_b_dev); - OpC op_c = CreateOp(params->scale_c, params->scale_c_dev); - - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - op_a, op_b, op_c, num_split); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - } - return ret; -} - -#endif // USE_COMPOSABLE_KERNEL - -template -class GemmFloat8TunableOp : public TunableOp> { - public: - GemmFloat8TunableOp() { -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#else - ORT_ENFORCE(false, "CK is required to support GemmFloat8 computing"); -#endif // USE_COMPOSABLE_KERNEL - } -}; - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu deleted file mode 100644 index 4c691dd18f2e9..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -namespace internal { -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. -// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. 
-// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu deleted file mode 100644 index 49463e58886f8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| 
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 8, 4, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 8, 4, 32, 32, 3, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 8, 4, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 
1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 12, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 16, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 8, 4, 32, 32, 3, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 8, 4, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index 236e5555051fc..0000000000000 --- 
a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, 
F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 
8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu deleted file mode 100644 index 1a0d45df82a71..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | 
| Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 
S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - 
ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index a0628802ec09e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - 
DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 
1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - 
std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc deleted file mode 100644 index 7dbb24463961e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::common; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); -// class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskBiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NGramRepeatBlock); - -// These ops were experimental ops in onnx domain which have been removed now. 
We add them here as -// contrib ops to maintain backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LongformerAttention); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8); - -#ifdef ENABLE_ATEN -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); -#endif - -#ifdef ENABLE_TRAINING_OPS -// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or -// 2). this is needed by inference for other purpose. 
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); -#endif - -#ifdef ORT_USE_NCCL -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); -#endif - -template <> -KernelCreateInfo BuildKernelCreateInfo() { - KernelCreateInfo info; - return info; -} - -// clang-format off -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as - // contrib ops to maintain backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // TransposedMatMul is still here for backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - -#ifdef ENABLE_ATEN - BuildKernelCreateInfo, -#endif - -#ifdef ENABLE_TRAINING_OPS - // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or - // 2). this is needed by inference for other purpose. 
- BuildKernelCreateInfo, -#endif - -#ifdef ORT_USE_NCCL - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif - - }; - - for (auto& function_table_entry : function_table) { - KernelCreateInfo info = function_table_entry(); - if (info.kernel_def != nullptr) { // filter disabled entries where type is void - ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); - } - } - - return Status::OK(); -} -// clang-format on - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h deleted file mode 100644 index db9a5d4fcd83e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index a5ab63d74df24..8929c6b7cf6e4 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -4,6 +4,7 @@ #include "contrib_ops/webgpu/bert/attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/flash_attention.h" #include "contrib_ops/webgpu/bert/multihead_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -165,7 +166,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let query_pos = m + local_id.y + past_sequence_length;\n" << " let key_pos = n + local_id.x;\n" << " if (key_pos > query_pos) {\n" - << " sum = -3.40282e+38; // Set to very negative value for masking\n" + << " sum = -3.4028234663852886e+38; // Set to very negative value for masking\n" << " }\n"; } @@ -272,7 +273,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let effective_seq_length = seq_causal_length;\n"; } shader.MainFunctionBody() - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" << " let actual_pos = local_offset + i + start_offset;\n" << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" @@ -289,7 +290,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } else if (use_smooth_softmax_) { shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; } else { - shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; + shader.MainFunctionBody() << "var max_value = f32(-3.4028234663852886e+38f);\n"; } shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" @@ -736,6 +737,19 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) // Compute Q, K, V from input, weights, and bias ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V)); + // Check if we can use flash attention + // For Attention operator, we need to create present_key and present_value tensors for flash attention + // even though they 
are not exposed as outputs + TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.num_heads_, + parameters.total_sequence_length_, parameters.head_size_}); + Tensor present_key = context.CreateGPUTensor(input->DataType(), present_kv_shape); + Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape); + + if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) { + return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value, + parameters, context, nullptr); + } + // Apply the actual attention computation return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr, /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 606dbfde15c2c..47a223f1bed28 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -76,7 +76,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { } else { shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; } - shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; + shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"; + } else { + shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n"; + } // Add indirect dispatch logic for thread 0 if (prepare_indirect_dispatch_) { @@ -93,8 +98,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_past_) { const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); shader.AddInput("past_value", ShaderUsage::UseUniform); - shader.MainFunctionBody() << "let present_offset = global_idx;" - << "if (sequence_id < past_sequence_length) {\n" + shader.MainFunctionBody() << "if (sequence_id < past_sequence_length) {\n" << " let pastOffset = " << past_key.IndicesToOffset("past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n" << " " << present_key.SetByOffset("present_offset", "past_key[pastOffset]") << ";\n" << " " << present_value.SetByOffset("present_offset", "past_value[pastOffset]") << ";\n" @@ -104,8 +108,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n" << "}"; } else { - shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n" - << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n" + shader.MainFunctionBody() << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? 
"key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n" << " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n" << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"; } @@ -134,10 +137,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt // Determine if we need to prepare indirect dispatch bool prepare_indirect_dispatch = (indirect_buffer != nullptr); bool use_seqlen_k = (seqlen_k != nullptr); - - CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH, + bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH; + CopyKVCacheProgram program{"CopyKVCache", has_past, kv_BNSH, parameters.past_present_share_buffer_, prepare_indirect_dispatch, use_seqlen_k}; - if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { + if (kv_BNSH) { program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); } else { @@ -207,6 +210,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_), + WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_), WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_), WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_)); @@ -256,10 +260,20 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte {metadata, ProgramTensorMetadataDependency::Rank, 2}}); const uint32_t vectorized_head_size = parameters.head_size_ / components; + + // Get attention bias dimensions for broadcasting + uint32_t attn_bias_dim0 = 1; + uint32_t attn_bias_dim1 = 1; + if (has_attention_bias) { + const auto& bias_shape = attention_bias->Shape(); + attn_bias_dim0 = static_cast(bias_shape[0]); + attn_bias_dim1 = static_cast(bias_shape[1]); + } + if (use_indirect_dispatch) { program.SetIndirectDispatchTensor(indirect_buffer); } else { - program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile); + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile); } program.SetWorkgroupSize(64) .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch) @@ -269,7 +283,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte present_sequence_length, {static_cast(parameters.n_reps)}, {num_present_sequence_length_tile}, - {static_cast(parameters.num_heads_)}}); + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.batch_size_)}, + {attn_bias_dim0}, + {attn_bias_dim1}}); return context.RunProgram(program); } @@ -313,11 +330,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte {qk, ProgramTensorMetadataDependency::TypeAndRank}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size] + const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); if (use_indirect_dispatch) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}) .SetIndirectDispatchTensor(indirect_buffer); } else { - 
program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile); + program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile); } program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch) .SetWorkgroupSize(64) @@ -326,7 +344,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte present_sequence_length, {static_cast(parameters.n_reps)}, num_present_sequence_length_tile, - {static_cast(parameters.num_heads_)}}); + {batch_heads}}); return context.RunProgram(program); } @@ -363,14 +381,15 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}}); const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); - program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile) + const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); + program.SetDispatchGroupSize(batch_heads * num_head_size_tile) .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, num_present_sequence_length_tile, {num_head_size_tile}, - {static_cast(parameters.num_heads_)}}); + {batch_heads}}); return context.RunProgram(program); } @@ -421,7 +440,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co indirect_buffer_ptr, tile_size)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); } if (parameters.sequence_length_ > 1) { @@ -429,6 +448,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; FlashAttentionProgram program{"FlashAttention", has_attention_bias, is_qualcomm, @@ -437,6 +457,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co parameters.num_heads_, parameters.is_unidirectional_, is_nvidia, + q_BNSH, use_seqlen_k}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, @@ -451,15 +472,28 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const float alpha = parameters.scale_ == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; - program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile) + + // Get attention bias dimensions for broadcasting + uint32_t attn_bias_dim0 = 1; + uint32_t attn_bias_dim1 = 1; + if (has_attention_bias) { + const auto& bias_shape = attention_bias->Shape(); + attn_bias_dim0 = static_cast(bias_shape[0]); + attn_bias_dim1 = static_cast(bias_shape[1]); + } + + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, + {static_cast(parameters.batch_size_)}, {static_cast(parameters.n_reps)}, {alpha}, - {num_seq_tile}}); + {num_seq_tile}, + {attn_bias_dim0}, + {attn_bias_dim1}}); return context.RunProgram(program); } @@ -500,8 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { - return parameters.batch_size_ == 1 && - !parameters.is_packed_qkv_ && + return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && bias == nullptr && context.HasFeature(wgpu::FeatureName::Subgroups) && @@ -571,8 +604,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index a936a91695921..bb8c8de8c8ab9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -43,9 +43,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { public: - CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, + CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer, bool prepare_indirect_dispatch = false, bool use_seqlen_k = false) - : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) { + : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -59,6 +59,7 @@ class CopyKVCacheProgram final : public Program { private: bool has_past_; bool kv_BNSH_; + bool past_present_share_buffer_; bool prepare_indirect_dispatch_; bool use_seqlen_k_; }; @@ -73,6 +74,7 
@@ class FlashAttentionProgram final : public Program { int qkv_num_heads, bool is_unidirectional, bool is_nvidia, + bool q_BNSH, bool use_seqlen_k = false) : Program{kernel_name}, has_attention_bias_(has_attention_bias), @@ -82,6 +84,7 @@ class FlashAttentionProgram final : public Program { qkv_num_heads_(qkv_num_heads), is_unidirectional_(is_unidirectional), is_nvidia_(is_nvidia), + q_BNSH_(q_BNSH), use_seqlen_k_(use_seqlen_k) { } @@ -90,9 +93,12 @@ class FlashAttentionProgram final : public Program { WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32}, {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"batch_size", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, {"alpha", ProgramUniformVariableDataType::Float32}, - {"num_seq_tile", ProgramUniformVariableDataType::Uint32}); + {"num_seq_tile", ProgramUniformVariableDataType::Uint32}, + {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32}, + {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}); private: bool has_attention_bias_; @@ -102,6 +108,7 @@ class FlashAttentionProgram final : public Program { int qkv_num_heads_; bool is_unidirectional_; bool is_nvidia_; + bool q_BNSH_; bool use_seqlen_k_; }; @@ -120,7 +127,10 @@ class FlashAttentionDecodeQKTProgram final : public Program u32 { #if is_fp16 const min_value = q_element_t(-65504.0); #else -const min_value = q_element_t(-3.402823e+38f); +const min_value = q_element_t(-3.4028234663852886e+38f); #endif // For max performance max_k_step should be the same as sg_size, however we might run out of registers @@ -42,20 +43,27 @@ var v_tile : array, max_k_step>; // Private memory per lane. var q_tile : array; -fn loadq(q_idx_global : u32, head_idx : u32, alpha : q_element_t) { - // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA - // This is the layout if TransferBSDToBNSH has not been run. - let offset = q_idx_global * (head_size_vec)*num_heads + head_size_vec * head_idx; - // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. 
- // let offset = head_idx * uniforms.new_sequence_length * head_size_vec + q_idx_global * head_size_vec; +fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_t) { +#if q_BNSH + // Stored as BNSH - float16[batch_size, num_heads, sequence_length, head_size] + let offset = batch_idx * num_heads * uniforms.new_sequence_length * head_size_vec + + head_idx * uniforms.new_sequence_length * head_size_vec + + q_idx_global * head_size_vec; +#else + // Stored as BSNH - float16[batch_size, sequence_length, num_heads, head_size] + let offset = batch_idx * uniforms.new_sequence_length * head_size_vec * num_heads + + q_idx_global * head_size_vec * num_heads + + head_idx * head_size_vec; +#endif for (var idx : u32 = 0; idx < head_size_vec; idx++) { q_tile[idx] = q[idx + offset] * alpha; } } -fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; + let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + + k_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length()); @@ -63,9 +71,10 @@ fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) { } } -fn loadv(v_start : u32, head_idx : u32, local_idx : u32, v_step : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; + let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + + v_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length()); @@ -83,9 +92,10 @@ var o_tile_r : array, workgroup_ // Private memory per lane. var o_tile : array; -fn writeo(o_idx_global : u32, head_idx : u32, local_idx : u32) { +fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32, local_idx : u32) { // Stored as float16[batch_size,sequence_length,3072] - let offset = o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec; + let offset = batch_idx * uniforms.new_sequence_length * num_heads * head_size_vec + + o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec; for (var idx : u32 = 0; idx < half_head_size_vec; idx++) { output[offset + idx] = o_tile[idx]; output[offset + idx + half_head_size_vec] = o_tile_r[local_idx][idx]; @@ -94,9 +104,10 @@ fn writeo(o_idx_global : u32, head_idx : u32, local_idx : u32) { #else // Private memory per lane. 
var o_tile : array; -fn writeo(o_idx_global : u32, head_idx : u32) { +fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { // Stored as float16[batch_size,sequence_length,3072] - let offset = o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec; + let offset = batch_idx * uniforms.new_sequence_length * num_heads * head_size_vec + + o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec; for (var idx : u32 = 0; idx < head_size_vec; idx++) { output[offset + idx] = o_tile[idx]; } @@ -104,12 +115,17 @@ fn writeo(o_idx_global : u32, head_idx : u32) { #endif #if has_attention_bias -fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= get_total_sequence_length()) { + if (k_idx_global >= get_total_sequence_length()) { return vec4(0); } - let offset_base = head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); + // Handle broadcasting: if dimension size is 1, use index 0 + let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); + let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); + + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + + bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); let offset = offset_base + k_idx_global; let offset_max = offset_base + get_total_sequence_length(); let c1 = q_element_t(attention_bias[min(offset, offset_max)]); @@ -119,7 +135,7 @@ fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> return vec4(c1, c2, c3, c4); } #else -fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { return vec4(0); } #endif @@ -141,15 +157,21 @@ fn fetchVTile(k_idx: u32, vec_idx: u32, v_val: q_value_t) -> q_value_t { } $MAIN { - let head_idx = u32(workgroup_idx / uniforms.num_seq_tile); + let batch_head_idx = u32(workgroup_idx / uniforms.num_seq_tile); + let head_idx = batch_head_idx % num_heads; + let batch_idx = batch_head_idx / num_heads; let capped_sg_id = min(sg_id, max_k_step - 1u); let capped_sg_size = min(sg_size, max_k_step); + if (batch_idx >= uniforms.batch_size) { + return; + } + // Load Q let q_idx_global = (workgroup_idx % uniforms.num_seq_tile) * workgroup_size_x + local_idx; let valid_q = q_idx_global < uniforms.new_sequence_length; if (valid_q) { - loadq(q_idx_global, head_idx, q_element_t(uniforms.alpha)); + loadq(batch_idx, q_idx_global, head_idx, q_element_t(uniforms.alpha)); } var previous_max : q_element_t = min_value; @@ -170,8 +192,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) { workgroupBarrier(); - loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); - loadv(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); + loadk(k_start, batch_head_idx, local_idx, capped_sg_size); + loadv(k_start, batch_head_idx, local_idx, capped_sg_size); workgroupBarrier(); // Compute QKt @@ -229,11 +251,11 @@ $MAIN { qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - qk_1 = qk_1 
+ loadAttentionBias(q_idx_global, k_start, head_idx); - qk_2 = qk_2 + loadAttentionBias(q_idx_global, k_start + 4, head_idx); + qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx); + qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx); if (sg_size > 8) { - qk_3 = qk_3 + loadAttentionBias(q_idx_global, k_start + 8, head_idx); - qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start + 12, head_idx); + qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx); + qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx); } // Neuter qk values where K is out of bounds. @@ -360,7 +382,7 @@ $MAIN { } if (valid_q) { - writeo(q_idx_global, head_idx, local_idx); + writeo(batch_idx, q_idx_global, head_idx, local_idx); } #else if (sg_size > 8) { @@ -409,7 +431,7 @@ $MAIN { } if (valid_q) { - writeo(q_idx_global, head_idx); + writeo(batch_idx, q_idx_global, head_idx); } #endif } // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template index c6f768beffa0f..e7944231f342e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template @@ -35,12 +35,22 @@ var inner_qk_values: array, tile_ var tile_qk: array; #if has_attention_bias - fn loadAttentionBias(idx: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t { - return attention_bias[idx]; + // Handle broadcasting: if dimension size is 1, use index 0 + let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); + let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); + + // Calculate flat offset with broadcasting applied + // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length] + // For decode, new_seq_length is 1, so we can simplify: + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length + + bias_head_idx * total_seq_length + + k_idx; + return attention_bias[offset]; } #else - fn loadAttentionBias(idx: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t { return q_element_t(0); } @@ -56,9 +66,14 @@ $MAIN { #endif let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let head_idx = u32(workgroup_idx / num_total_seq_length_tile); - let q_offset = head_idx * uniforms.head_size_vec; - let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.head_size_vec; + let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + if (batch_idx >= uniforms.batch_size) { + return; + } + let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec; + let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { tile_q[local_idx] = q[q_offset + k + local_idx]; @@ -75,25 +90,21 
@@ $MAIN { workgroupBarrier(); } - if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length && head_idx < uniforms.num_heads) { + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { var sum = q_element_t(0); for (var i = 0u; i < tile_size_k_vec; i++) { sum += inner_qk_values[local_idx][i]; } - sum = sum + loadAttentionBias(head_idx * total_sequence_length + total_seq_offset + local_idx); + sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length); tile_qk[local_idx] = sum; - output[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum; + output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum; } workgroupBarrier(); - if (head_idx >= uniforms.num_heads) { - return; - } - if (local_idx == 0u) { // Calculate the max and sum in current split. - var l_max = f32(-3.402823e+38f); + var l_max = f32(-3.4028234663852886e+38f); var l_sum = f32(0); for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_max = max(l_max, f32(tile_qk[i])); @@ -101,7 +112,7 @@ $MAIN { for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_sum += exp(f32(tile_qk[i]) - l_max); } - let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; metadata[meta_offset] = metadata_value_t(l_max, l_sum); } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template index 37cf7e8f11b1f..8139477172b03 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template @@ -48,30 +48,31 @@ $MAIN { #endif let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let head_idx = u32(workgroup_idx / num_total_seq_length_tile); - let present_offset = u32(head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length; + let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); + if (batch_head_idx >= uniforms.batch_heads) { + return; + } + let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length; // Calculate the global max and sum in qk. 
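The reduction that follows merges the per-tile (max, sum) metadata written by the decode QKT pass into a single softmax normalizer. A minimal C++ sketch of that merge, separate from the WGSL template and using hypothetical names, assuming each tile i stored (l_max_i, l_sum_i) with l_sum_i = sum_j exp(qk_ij - l_max_i):

#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

// Returns g_sum = sum_i exp(l_max_i - g_max) * l_sum_i, with g_max = max_i l_max_i,
// which is the global softmax denominator recovered from the per-tile partial results.
float MergeSplitSoftmax(const std::vector<std::pair<float, float>>& tile_metadata, float& g_max) {
  g_max = -std::numeric_limits<float>::max();
  for (const auto& m : tile_metadata) g_max = std::max(g_max, m.first);
  float g_sum = 0.0f;
  for (const auto& m : tile_metadata) g_sum += std::exp(m.first - g_max) * m.second;
  return g_sum;
}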
- if (head_idx < uniforms.num_heads) + var g_max = f32(-3.4028234663852886e+38f); + var g_sum = f32(0); + for (var i = 0u; i < num_total_seq_length_tile; i++) { - var g_max = f32(-3.402823e+38f); - var g_sum = f32(0); - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i; - g_max = max(g_max, metadata[meta_offset].x); - } - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i; - let m_value = metadata[meta_offset]; - g_sum += exp(m_value.x - g_max) * m_value.y; - } + let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; + g_max = max(g_max, metadata[meta_offset].x); + } + for (var i = 0u; i < num_total_seq_length_tile; i++) + { + let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; + let m_value = metadata[meta_offset]; + g_sum += exp(m_value.x - g_max) * m_value.y; + } if (total_seq_offset + local_idx < total_sequence_length) { - tile_qk[local_idx] = present_value_element_t(exp(f32(qk[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); + tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); } - } + for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) { var value = present_value_value_t(0); qkv_values[local_row][local_col] = present_value_value_t(0); @@ -96,12 +97,8 @@ $MAIN { workgroupBarrier(); } - if (head_idx >= uniforms.num_heads) { - return; - } - for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) { - let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i; + let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i; out_split_vx[out_offset] = tile_output[i]; } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index 22f18655307de..f909a87724da6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -20,8 +20,11 @@ var tile_input: array, tile_size>; $MAIN { let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size; - let head_idx = u32(workgroup_idx / uniforms.num_head_size_tile); - let in_offset = head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; + let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile); + if (batch_head_idx >= uniforms.batch_heads) { + return; + } + let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; var value = output_value_t(0); let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; @@ -43,16 +46,12 @@ $MAIN { tile_input[local_row][local_col] = value; workgroupBarrier(); - if (head_idx >= uniforms.num_heads) { - return; - } - if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { value = output_value_t(0); for (var i = 0u; i < tile_size; i++) { value += tile_input[i][local_idx]; } - let output_id = head_idx * uniforms.head_size_vec + head_size_offset + 
local_idx; + let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx; output[output_id] = value; } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 05717fd2fe686..7ca61008be83f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -128,8 +128,8 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); @@ -287,7 +287,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(attention_bias, present_key, present_value, temp_params, context); + will_use_flash_attention = CanApplyFlashAttention(nullptr, present_key, present_value, temp_params, context); } if (parameters.is_packed_qkv_ && do_rotary_) { diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template index 1214777009a8d..6e0d4c7299793 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template @@ -18,7 +18,7 @@ const K: u32 = k; #if is_fp16 const MAX_FLOAT: f16 = 65504.0; #else -const MAX_FLOAT: f32 = 3.402823466e+38; +const MAX_FLOAT: f32 = 3.4028234663852886e+38; #endif var shared_vals: array; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index e77496b6e8196..1c80d83f99feb 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -499,8 +499,7 @@ class PlannerImpl { /*! \brief Given a tensor-type, return the size of an element of the tensor. 
*/ static size_t GetElementSize(const DataType& tensor_type) { - const TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); + MLDataType ml_data_type = DataTypeImpl::GetDataType(*tensor_type); const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); ORT_ENFORCE(nullptr != tensor_type_base); MLDataType elt_type = tensor_type_base->GetElementType(); diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 76e7e369514d4..6035dc4e85242 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,7 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; - auto it = map_.find(std::string(name)); + auto it = map_.find(name); if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } diff --git a/onnxruntime/core/framework/plugin_ep_stream.cc b/onnxruntime/core/framework/plugin_ep_stream.cc index 1eb6ad4162f33..becaefa041ce5 100644 --- a/onnxruntime/core/framework/plugin_ep_stream.cc +++ b/onnxruntime/core/framework/plugin_ep_stream.cc @@ -3,10 +3,40 @@ #include "core/framework/plugin_ep_stream.h" #include "core/framework/error_code_helper.h" +#include "core/session/abi_logger.h" + +using namespace ::onnxruntime::logging; + +#define LOG_AND_RETURN_IF_ORT_ERROR(fn, logger) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + LOGS(logger, ERROR) << "Plug-in EP Error: [" << OrtApis::GetErrorCode(_status) << "] " \ + << OrtApis::GetErrorMessage(_status); \ + OrtApis::ReleaseStatus(_status); \ + return; \ + } \ + } while (0) namespace onnxruntime { namespace plugin_ep { +void Notification::WaitNotificationOnDevice(onnxruntime::Stream* stream, synchronize::Notification& notification) { + auto* this_ptr = static_cast(¬ification); + + LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnDevice(&this_ptr->impl_, static_cast(stream)), + *this_ptr->logger_.ToInternal()); +} +void Notification::WaitNotificationOnHost(onnxruntime::Stream*, synchronize::Notification& notification) { + auto* this_ptr = static_cast(¬ification); + LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnHost(&this_ptr->impl_), *this_ptr->logger_.ToInternal()); +} +void Notification::Activate() { + LOG_AND_RETURN_IF_ORT_ERROR(impl_.Activate(&impl_), *logger_.ToInternal()); +} +void Stream::Flush() { + LOG_AND_RETURN_IF_ORT_ERROR(impl_.Flush(&impl_), *logger_.ToInternal()); +} // TODO: Is num_consumers meaningful? Unused everywhere currently. OrtStatus* Stream::CreateNotificationImpl(size_t /*num_consumers*/, std::unique_ptr& result) { OrtSyncNotificationImpl* notification_impl = nullptr; diff --git a/onnxruntime/core/framework/plugin_ep_stream.h b/onnxruntime/core/framework/plugin_ep_stream.h index 09938403ad9b5..6080c17f8f36a 100644 --- a/onnxruntime/core/framework/plugin_ep_stream.h +++ b/onnxruntime/core/framework/plugin_ep_stream.h @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
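The ort_value_name_idx_map.h change above drops the temporary std::string by passing the std::string_view directly to map_.find(). A minimal sketch of the idea, assuming a C++20 container with a transparent hash and equality (the actual map type used by ORT may differ):

#include <string>
#include <string_view>
#include <unordered_map>

struct TransparentStringHash {
  using is_transparent = void;  // opt in to heterogeneous lookup
  size_t operator()(std::string_view sv) const noexcept { return std::hash<std::string_view>{}(sv); }
};

using NameIdxMap = std::unordered_map<std::string, int, TransparentStringHash, std::equal_to<>>;

// Looks up by string_view without constructing a std::string temporary.
int GetIdxOrMinusOne(const NameIdxMap& map, std::string_view name) {
  auto it = map.find(name);
  return it == map.end() ? -1 : it->second;
}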
#pragma once -#include "core/common/logging/logging.h" #include "core/framework/stream_handles.h" #include "core/framework/error_code_helper.h" #include "core/session/onnxruntime_c_api.h" @@ -13,43 +12,20 @@ struct OrtSyncStream : public onnxruntime::Stream {}; struct OrtSyncNotification : onnxruntime::synchronize::Notification {}; -using onnxruntime::logging::Logger; - -#define LOG_AND_RETURN_IF_ORT_ERROR(fn, logger) \ - do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - LOGS(logger, ERROR) << "Plug-in EP Error: [" << OrtApis::GetErrorCode(_status) << "] " \ - << OrtApis::GetErrorMessage(_status); \ - OrtApis::ReleaseStatus(_status); \ - return; \ - } \ - } while (0) - namespace onnxruntime { namespace plugin_ep { class Notification : public synchronize::Notification { public: - Notification(Stream& stream, OrtSyncNotificationImpl& impl, const Logger& logger) + Notification(Stream& stream, OrtSyncNotificationImpl& impl, const OrtLogger& logger) : synchronize::Notification(stream), impl_{impl}, logger_{logger} { } - static void WaitNotificationOnDevice(onnxruntime::Stream* stream, synchronize::Notification& notification) { - auto* this_ptr = static_cast(¬ification); + static void WaitNotificationOnDevice(onnxruntime::Stream* stream, synchronize::Notification& notification); - LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnDevice(&this_ptr->impl_, static_cast(stream)), - this_ptr->logger_); - } + static void WaitNotificationOnHost(onnxruntime::Stream* /*stream*/, synchronize::Notification& notification); - static void WaitNotificationOnHost(onnxruntime::Stream* /*stream*/, synchronize::Notification& notification) { - auto* this_ptr = static_cast(¬ification); - LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnHost(&this_ptr->impl_), this_ptr->logger_); - } - - void Activate() override { - LOG_AND_RETURN_IF_ORT_ERROR(impl_.Activate(&impl_), logger_); - } + void Activate() override; ~Notification() override { impl_.Release(&impl_); @@ -57,12 +33,12 @@ class Notification : public synchronize::Notification { private: OrtSyncNotificationImpl& impl_; - const Logger& logger_; + const OrtLogger& logger_; }; class Stream : public onnxruntime::Stream { public: - Stream(const OrtDevice& memory_device, OrtSyncStreamImpl& impl, const logging::Logger& logger) + Stream(const OrtDevice& memory_device, OrtSyncStreamImpl& impl, const OrtLogger& logger) : onnxruntime::Stream(impl.GetHandle(&impl), memory_device), impl_{impl}, logger_{logger} { } @@ -78,9 +54,7 @@ class Stream : public onnxruntime::Stream { return plugin_notification; } - void Flush() override { - LOG_AND_RETURN_IF_ORT_ERROR(impl_.Flush(&impl_), logger_); - } + void Flush() override; Status CleanUpOnRunEnd() override { auto* ort_status = impl_.OnSessionRunEnd(&impl_); @@ -103,7 +77,7 @@ class Stream : public onnxruntime::Stream { OrtStatus* CreateNotificationImpl(size_t num_consumers, std::unique_ptr& result); OrtSyncStreamImpl& impl_; - const Logger& logger_; + const OrtLogger& logger_; }; } // namespace plugin_ep } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 8fb3dc63aa4d1..a14e219d9c039 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1179,7 +1179,7 @@ Status SessionState::CreateSubgraphSessionState() { const auto& ep = node.GetExecutionProviderType(); if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && - ep != kRocmExecutionProvider && ep 
!= kDmlExecutionProvider && + ep != kDmlExecutionProvider && ep != kJsExecutionProvider && ep != kWebGpuExecutionProvider) { // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow // node containing the subgraph it will create whatever state it needs internally. diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h index bc52a45adfd43..94ef87fb069af 100644 --- a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h +++ b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h @@ -83,7 +83,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nchw_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input if (rank < 3) { - fail_shape_inference("Output tensor must have at least 3 dimensions"); + *nhwc_tp.mutable_tensor_type()->mutable_shape() = nchw_shape; + return; } // Convert output shape from N, C, H {, W, ...} to N, H {, W, ...}, C. @@ -105,8 +106,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nhwc_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input. if (rank < 3) { - fail_shape_inference( - "Tensor must have at least 3 dimensions to convert between channels first and channels last."); + *nchw_tp.mutable_tensor_type()->mutable_shape() = nhwc_shape; + return; } // Convert input shape from {N, H, W, ..., C} to {N, C, H, W, ...}. diff --git a/onnxruntime/core/graph/data_propagation/add_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/add_op_data_propagation.cc new file mode 100644 index 0000000000000..172941c0ee023 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/add_op_data_propagation.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "add_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" + +namespace onnxruntime { + +Status AddOpDataPropagation::infer() { + // Get "A" input + const auto* input_0 = node_.InputDefs()[0]; + // Get "B" input + const auto* input_1 = node_.InputDefs()[1]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_1 || !input_0->Exists() || !input_1->Exists()) { + return Status::OK(); + } + + if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) { + output_def_.SetInferredShapeScalarValue( + input_0->GetInferredShapeScalarValue().value() + + input_1->GetInferredShapeScalarValue().value()); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/add_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/add_op_data_propagation.h new file mode 100644 index 0000000000000..f9eb9990142c1 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/add_op_data_propagation.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +/** + * @brief Class to infer the output scalar for 'Add' operator given the input is a scalar related to shape. 
+ * + * + * For example: + * + * (input with the shape as float32[1, 3, 64, 64]) + * | + * v + * Shape (It saves [1, 3, 64, 64] in inferred_shape_values_ in output's node_arg + * | during Graph::SaveShapeValuesFromDataPropagation()) + * | + * | ______ + * | | + * v v + * Gather Gather (First 'Gather' saves 3 in inferred_scalar_value_ in output node_arg, and + * | | second 'Gather' saves 64 in inferred_scalar_value_ in output node_arg + * | | during GatherOpDataPropagation(), if the 'index' attributes + * | | are 1 and 2 respectively) + * \ / + * \ / + * | | + * v v + * Add (It gets 3 from inferred_scalar_value_ in input A's node_arg and 64 from inferred_scalar_value_ + * | in input B's node_arg, then performs add operation to get 67 and saves in inferred_scalar_value_ + * | in output's node_arg) + * v + * ... + */ +class AddOpDataPropagation : public CustomDataPropagationBase { + public: + AddOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/custom_data_propagation.cc b/onnxruntime/core/graph/data_propagation/custom_data_propagation.cc new file mode 100644 index 0000000000000..b7254aa828107 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/custom_data_propagation.cc @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "custom_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/graph.h" +#include "core/common/logging/logging.h" +#include "size_op_data_propagation.h" +#include "squeeze_op_data_propagation.h" +#include "unsqueeze_op_data_propagation.h" +#include "gather_op_data_propagation.h" +#include "add_op_data_propagation.h" +#include "sub_op_data_propagation.h" +#include "mul_op_data_propagation.h" +#include "div_op_data_propagation.h" +#include + +namespace onnxruntime { + +std::unique_ptr CreateCustomDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) { + int dim_size = 0; + if (output_from_onnx_op_data_propagation.has_tensor_type() && + output_from_onnx_op_data_propagation.tensor_type().has_shape()) { + dim_size = output_from_onnx_op_data_propagation.tensor_type().shape().dim_size(); + } + + if (node.OpType() == "Size") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } else if (node.OpType() == "Squeeze") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } else if (node.OpType() == "Unsqueeze") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } else if (dim_size == 0) { + if (node.OpType() == "Gather") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } else if (node.OpType() == "Add") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } else if (node.OpType() == "Sub") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } 
else if (node.OpType() == "Mul") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } else if (node.OpType() == "Div") { + return std::make_unique(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger); + } + } + return nullptr; +} + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/graph/data_propagation/custom_data_propagation.h b/onnxruntime/core/graph/data_propagation/custom_data_propagation.h new file mode 100644 index 0000000000000..7511f77f58519 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/custom_data_propagation.h @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/graph/graph.h" +#include "core/common/logging/logging.h" +#include + +namespace onnxruntime { + +/** + * @class CustomDataPropagation + * Custom data propagation for the operator to help enhance shape inference. + * + * Calling infer() can infer the output values for the specific operator given the input is shape values + * and saves the output values in output node_arg for other operators to use later. + * The purpose of this class is to make shape values being correctly inferred and propogated through the graph. + */ +class CustomDataPropagationBase { + public: + ORT_DISALLOW_COPY(CustomDataPropagationBase); + virtual ~CustomDataPropagationBase() = default; + virtual Status infer() = 0; + + protected: + CustomDataPropagationBase(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : node_(node), + output_def_(output_def), + get_initialized_input_values_func_(std::move(func)), + output_from_onnx_op_data_propagation_(output_from_onnx_op_data_propagation), + logger_(logger) {} + + const Node& node_; + NodeArg& output_def_; + std::function get_initialized_input_values_func_; + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation_; + const logging::Logger& logger_; +}; + +/** + * @brief Create custom data propagation for the operator. + * + * For certain operators (e.g., Size, Squeeze, Unsqueeze), ONNX's + * PartialDataPropagationFunction() does not always produce complete or accurate + * inferred shape values. + * + * In particular: + * - Scalar inputs and outputs are not handled correctly. + * - Some operators require additional logic that is not covered by the default function, + e.g. PartialDataPropagationFunction. + * + * Therefore, for these cases, we perform custom data propagation to ensure + * correct and complete inference. 
+ * + * @param node The ORT's node + * @param output_def The node's output NodeArg to save the inferred shape values if needed + * @param func Helper function to get the input value if it's a initializer + * @param output_from_onnx_op_data_propagation The result from executing ONNX operator's data propagation + * @param logger The reference to a logger + * @return std::unique_ptr Returns a CustomDataPropagation object if available + */ +std::unique_ptr CreateCustomDataPropagation( + const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger); + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/div_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/div_op_data_propagation.cc new file mode 100644 index 0000000000000..2ea9b3047941c --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/div_op_data_propagation.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "div_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" + +namespace onnxruntime { + +Status DivOpDataPropagation::infer() { + // Get "A" input + const auto* input_0 = node_.InputDefs()[0]; + // Get "B" input + const auto* input_1 = node_.InputDefs()[1]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_1 || !input_0->Exists() || !input_1->Exists()) { + return Status::OK(); + } + + if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) { + output_def_.SetInferredShapeScalarValue( + input_0->GetInferredShapeScalarValue().value() / + input_1->GetInferredShapeScalarValue().value()); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/div_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/div_op_data_propagation.h new file mode 100644 index 0000000000000..9b32b59039282 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/div_op_data_propagation.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +/** + * @brief Class to infer the output scalar for 'Div' operator given the input is a scalar related to shape. + * + * + * For example: + * + * (input with the shape as float32[1, 3, 64, 64]) + * | + * v + * Shape (It saves [1, 3, 64, 64] in inferred_shape_values_ in output's node_arg + * | during graph::SaveShapeValuesFromDataPropagation()) + * | + * | ______ + * | | + * v v + * Gather Gather (First 'Gather' saves 64 in inferred_scalar_value_ in output node_arg, and + * | | second 'Gather' saves 1 in inferred_scalar_value_ in output node_arg + * | | during GatherOpDataPropagation(), if the 'index' attributes + * | | are 2 and 0 respectively) + * \ / + * \ / + * | | + * v v + * Div (It gets 64 from inferred_scalar_value_ in input A's node_arg and 1 from inferred_scalar_value_ + * | in input B's node_arg, then performs div operation to get 64 and saves in inferred_scalar_value_ + * | in output's node_arg) + * v + * ... 
+ */ +class DivOpDataPropagation : public CustomDataPropagationBase { + public: + DivOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/gather_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/gather_op_data_propagation.cc new file mode 100644 index 0000000000000..39ac926a8553f --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/gather_op_data_propagation.cc @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gather_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" + +namespace onnxruntime { + +Status GatherOpDataPropagation::infer() { + if (output_from_onnx_op_data_propagation_.has_tensor_type() && + output_from_onnx_op_data_propagation_.tensor_type().has_shape()) { + int dim_size = output_from_onnx_op_data_propagation_.tensor_type().shape().dim_size(); + // Check there is no result from Gather's PartialDataPropagationFunction(), + // so that it can run custom data propagation below. + // Otherwise, this infer() function won't be called as the result from Gather's PartialDataPropagationFunction() + // will be used in Graph::SaveShapeValuesFromDataPropagation(). + if (dim_size != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "ORT shouldn't run Gather's custom data propagation here as Gather's" + "PartialDataPropagationFunction() already infers shape values in the output."); + } + } + + // Following code extracts an element from a 1D array if all conditions are met. + // e.g. + // shape data is [1, 3, 64, 64] -> gets 64 if the index is 2. + // shape data is [1, 3, 64, 64] -> gets 3 if the index is 1. + + // Get "data" input + // Note: The "data" input should be an one dimension array in this case. + const auto* input_0 = node_.InputDefs()[0]; + + // Get "indices" input + // Note: The "indices" input could be one of the three cases: + // 1. A tensor with rank > 0 and all tensor values are known. + // 2. A tensor with rank > 0 but not all tensor values are known. + // 3. A scalar. + // + // If it's case #1, ONNX operator's PartialDataPropagationFunction() + // should have inferred the output shape value. + // If it's case #2, neither ONNX operator's PartialDataPropagationFunction() + // nor Gather's custom data propagation can handle it. + // This Gather's custom data propagation handles case #3. + const auto* input_1 = node_.InputDefs()[1]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_1 || !input_0->Exists() || !input_1->Exists()) { + return Status::OK(); + } + + // If input's inferred shape values is present, we then perfrom the gather operation on the shape values + // and saves the result in output's node_arg. 
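// For instance, inferred shape values of [1, 3, 64, 64] combined with a single index of 2
// (or -2 after HandleNegativeAxis) produce the scalar 64 on the output NodeArg.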
+ if (input_0->GetInferredShapeValues().has_value()) { + const auto& tensor_shape_proto = input_0->GetInferredShapeValues().value(); + + ORT_TRY { + TensorShapeVector indices; + ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), indices)); + if (indices.size() == 1) { + // Note: Index value is expected to be within bounds [-s, s-1] along axis of size s + auto index = static_cast( + HandleNegativeAxis(indices[0], tensor_shape_proto.dim_size())); + + auto& dim = tensor_shape_proto.dim(index); + if (dim.has_dim_value()) { + output_def_.SetInferredShapeScalarValue(dim.dim_value()); + } + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + LOGS(logger_, ERROR) << ex.what(); + LOGS(logger_, INFO) << "Skip Gather op custom data propagation."; + }); + return Status::OK(); + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/gather_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/gather_op_data_propagation.h new file mode 100644 index 0000000000000..c6b542e5af6e3 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/gather_op_data_propagation.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +/** + * @brief Class to infer the output scalar for 'Gather' operator given the input is shape values. + * + * + * For example: + * + * (input with the shape as float32[1, 3, 64, 64]) + * | + * v + * Shape (It saves [1, 3, 64, 64] in inferred_shape_values_ in output's node_arg + * | during graph::SaveShapeValuesFromDataPropagation()) + * | + * | ______ + * | | + * v v + * Gather Gather (First 'Gather' gets [1, 3, 64, 64] from input node_node's inferred_shape_values_, and + * | | then saves 3 in inferred_scalar_value_ in output node_args if 'index' attribute is 1. + * | | Same logic for second 'Gather', it saves 64 in inferred_scalar_value_ in output node_arga + * \ / if 'index" attribute is 2) + * \ / + * | | + * v v + * Mul (It gets 3 from inferred_scalar_value_ in input A's node_arg and 64 from inferred_scalar_value_ + * | in input B's node_arg, then performs mul operation to get 192 and saves in inferred_scalar_value_ + * | in output's node_arg) + * v + * ... + */ +class GatherOpDataPropagation : public CustomDataPropagationBase { + public: + GatherOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/mul_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/mul_op_data_propagation.cc new file mode 100644 index 0000000000000..4c5c25022e40b --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/mul_op_data_propagation.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
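The Mul propagator that starts here follows the same pattern as the Add and Div propagators above (and the Sub one later in this change): when both inputs carry a known scalar shape value, fold them with the operator's arithmetic and store the result on the output NodeArg; otherwise do nothing. A minimal standalone sketch of that shared pattern, where FoldScalarShapeValues and its std::optional signature are illustrative stand-ins for the NodeArg accessors rather than part of the change:

#include <cstdint>
#include <functional>
#include <optional>

// Combine two known scalar shape values, or report "unknown" when either side
// is missing (mirroring the early Status::OK() returns in the infer() bodies).
std::optional<int64_t> FoldScalarShapeValues(const std::optional<int64_t>& a,
                                             const std::optional<int64_t>& b,
                                             const std::function<int64_t(int64_t, int64_t)>& op) {
  if (!a.has_value() || !b.has_value()) {
    return std::nullopt;
  }
  return op(*a, *b);
}

// e.g. FoldScalarShapeValues(3, 64, std::multiplies<int64_t>{}) yields 192,
// matching the Gather -> Mul example in the header comments.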
+ +#include "mul_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" + +namespace onnxruntime { + +Status MulOpDataPropagation::infer() { + // Get "A" input + const auto* input_0 = node_.InputDefs()[0]; + // Get "B" input + const auto* input_1 = node_.InputDefs()[1]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_1 || !input_0->Exists() || !input_1->Exists()) { + return Status::OK(); + } + + if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) { + output_def_.SetInferredShapeScalarValue( + input_0->GetInferredShapeScalarValue().value() * + input_1->GetInferredShapeScalarValue().value()); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/mul_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/mul_op_data_propagation.h new file mode 100644 index 0000000000000..b42a591eb4c7b --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/mul_op_data_propagation.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" +namespace onnxruntime { + +/** + * @brief Class to infer the output scalar for 'Mul' operator given the input is a scalar related to shape. + * + * + * For example: + * + * (input with the shape as float32[1, 3, 64, 64]) + * | + * v + * Shape (It saves [1, 3, 64, 64] in inferred_shape_values_ in output's node_arg + * | during graph::SaveShapeValuesFromDataPropagation()) + * | + * | ______ + * | | + * v v + * Gather Gather (First 'Gather' saves 3 in inferred_scalar_value_ in output node_arg, and + * | | second 'Gather' saves 64 in inferred_scalar_value_ in output node_arg + * | | during GatherOpDataPropagation(), if the 'index' attributes + * | | are 1 and 2 respectively) + * \ / + * \ / + * | | + * v v + * Mul (It gets 3 from inferred_scalar_value_ in input A's node_arg and 64 from inferred_scalar_value_ + * | in input B's node_arg, then performs mul operation to get 192 and saves in inferred_scalar_value_ + * | in output's node_arg) + * v + * ... + */ +class MulOpDataPropagation : public CustomDataPropagationBase { + public: + MulOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/size_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/size_op_data_propagation.cc new file mode 100644 index 0000000000000..fe8ff3864296d --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/size_op_data_propagation.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "size_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" + +namespace onnxruntime { + +Status SizeOpDataPropagation::infer() { + // Size operator generates a scalar output + const auto* input_0 = node_.InputDefs()[0]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_0->Exists()) { + return Status::OK(); + } + + if (input_0->GetInferredShapeValues().has_value()) { + const auto& tensor_shape_proto = input_0->GetInferredShapeValues().value(); + + int64_t num_elements = 1; + // The TensorShapeProto (inferred shape values) should have rank > 0 and + // all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { + for (const auto& dim : tensor_shape_proto.dim()) { + if (!dim.has_dim_value()) { + return Status::OK(); // Or handle the error appropriately + } + num_elements *= dim.dim_value(); + } + + output_def_.SetInferredShapeScalarValue(num_elements); + } + } + + return Status::OK(); +} + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/graph/data_propagation/size_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/size_op_data_propagation.h new file mode 100644 index 0000000000000..184202254f078 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/size_op_data_propagation.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +/** + * @brief Class to infer the output scalar for 'Size' operator given the input is shape values. + * + * 'Size' operator takes a tensor as input and outputs a int64 scalar that equals to the total + * number of elements of the input tensor. + */ +class SizeOpDataPropagation : public CustomDataPropagationBase { + public: + SizeOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.cc new file mode 100644 index 0000000000000..b33d4f5589016 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "squeeze_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" +#include "core/common/inlined_containers.h" + +namespace onnxruntime { + +Status SqueezeOpDataPropagation::infer() { + const auto* input_0 = node_.InputDefs()[0]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_0->Exists()) { + return Status::OK(); + } + + if (input_0->GetInferredShapeValues().has_value()) { + const auto& tensor_shape_proto = input_0->GetInferredShapeValues().value(); + + // The TensorShapeProto (inferred shape values) should have rank > 0 and + // all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { + for (const auto& dim : tensor_shape_proto.dim()) { + if (!dim.has_dim_value()) { + return Status::OK(); + } + } + } + + if (tensor_shape_proto.dim_size() == 1) { + output_def_.SetInferredShapeScalarValue(tensor_shape_proto.dim(0).dim_value()); + } else if (tensor_shape_proto.dim_size() > 1) { + // Get axes value + TensorShapeVector axes; + InlinedHashSet axes_set; + + // Note: Starting from opset 13, "axes" is provided as a second input to the Squeeze operator. + // In opset 11 and earlier, "axes" is defined as a node attribute instead. + if (node_.InputDefs().size() > 1) { + const auto* input_1 = node_.InputDefs()[1]; + ORT_TRY { + ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), axes)); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + LOGS(logger_, ERROR) << ex.what(); + LOGS(logger_, INFO) << "Skip Squeeze op custom data propagation."; + }); + return Status::OK(); + } + } else { + const auto& attrs = node_.GetAttributes(); + auto it = attrs.find("axes"); + if (it != attrs.end()) { + const auto& axes_attr = it->second; + for (const auto& i : axes_attr.ints()) { + axes.push_back(i); + } + } + } + + ORT_TRY { + for (size_t i = 0; i < axes.size(); ++i) { + // Negative value means counting dimensions from the back. + axes_set.insert(HandleNegativeAxis(axes[i], tensor_shape_proto.dim_size())); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + LOGS(logger_, ERROR) << ex.what(); + LOGS(logger_, INFO) << "Skip Squeeze op custom data propagation."; + }); + return Status::OK(); + } + + auto& inferred_shape_values = output_def_.GetMutableInferredShapeValues(); + + if (!inferred_shape_values.has_value()) { + inferred_shape_values.emplace(); + } + inferred_shape_values->clear_dim(); + + int64_t dim_index = 0; + for (const auto& dim : tensor_shape_proto.dim()) { + auto value = dim.dim_value(); + if (axes_set.size() > 0) { + if (axes_set.find(dim_index) == axes_set.end()) { + inferred_shape_values->add_dim()->set_dim_value(value); + } + } else { + if (value != 1) { + inferred_shape_values->add_dim()->set_dim_value(value); + } + } + + dim_index++; + } + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.h new file mode 100644 index 0000000000000..15e1e8458525a --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +/** + * @brief Class to infer the output values/scalar for 'Squeeze' operator given the input is shape values. + * + */ +class SqueezeOpDataPropagation : public CustomDataPropagationBase { + public: + SqueezeOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/sub_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/sub_op_data_propagation.cc new file mode 100644 index 0000000000000..4ee8ab546e707 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/sub_op_data_propagation.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sub_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" + +namespace onnxruntime { + +Status SubOpDataPropagation::infer() { + // Get "A" input + const auto* input_0 = node_.InputDefs()[0]; + // Get "B" input + const auto* input_1 = node_.InputDefs()[1]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_1 || !input_0->Exists() || !input_1->Exists()) { + return Status::OK(); + } + + if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) { + output_def_.SetInferredShapeScalarValue( + input_0->GetInferredShapeScalarValue().value() - + input_1->GetInferredShapeScalarValue().value()); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/sub_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/sub_op_data_propagation.h new file mode 100644 index 0000000000000..a9be294b8f62f --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/sub_op_data_propagation.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +/** + * @brief Class to infer the output scalar for 'Sub' operator given the input is a scalar related to shape. + * + * + * For example: + * + * (input with the shape as float32[1, 3, 64, 64]) + * | + * v + * Shape (It saves [1, 3, 64, 64] in inferred_shape_values_ in output's node_arg + * | during graph::SaveShapeValuesFromDataPropagation()) + * | + * | ______ + * | | + * v v + * Gather Gather (First 'Gather' saves 64 in inferred_scalar_value_ in output node_arg, and + * | | second 'Gather' saves 3 in inferred_scalar_value_ in output node_arg + * | | during GatherOpDataPropagation(), if the 'index' attributes + * | | are 2 and 1 respectively) + * \ / + * \ / + * | | + * v v + * Sub (It gets 64 from inferred_scalar_value_ in input A's node_arg and 3 from inferred_scalar_value_ + * | in input B's node_arg, then performs sub operation to get 61 and saves in inferred_scalar_value_ + * | in output's node_arg) + * v + * ... 
+ */ +class SubOpDataPropagation : public CustomDataPropagationBase { + public: + SubOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.cc b/onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.cc new file mode 100644 index 0000000000000..ff5ef6853bc78 --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "unsqueeze_op_data_propagation.h" +#include "core/common/common.h" +#include "core/graph/node_arg.h" +#include "core/graph/onnx_protobuf.h" +#include "core/providers/common.h" +#include "core/common/inlined_containers.h" + +namespace onnxruntime { + +Status UnsqueezeOpDataPropagation::infer() { + const auto* input_0 = node_.InputDefs()[0]; + + // Return and do nothing if input doesn't exist + if (!input_0 || !input_0->Exists()) { + return Status::OK(); + } + + auto dim_size = output_from_onnx_op_data_propagation_.tensor_type().shape().dim_size(); + + if (dim_size == 0 && input_0->GetInferredShapeScalarValue().has_value()) { + // Following code expands a scalr to one dimension array, e.g. shape data is 64 -> it becomes [64] + // In this case, the axis should be 0 + auto& inferred_shape_values = output_def_.GetMutableInferredShapeValues(); + + if (!inferred_shape_values.has_value()) { + inferred_shape_values.emplace(); + } + inferred_shape_values->clear_dim(); + + inferred_shape_values->add_dim()->set_dim_value(input_0->GetInferredShapeScalarValue().value()); + } else if (input_0->GetInferredShapeValues().has_value()) { + const auto& tensor_shape_proto = input_0->GetInferredShapeValues().value(); + + // The TensorShapeProto (inferred shape values) should have rank > 0 and + // all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { + for (const auto& dim : tensor_shape_proto.dim()) { + if (!dim.has_dim_value()) { + return Status::OK(); + } + } + } + + if (tensor_shape_proto.dim_size() > 0) { + // Get axes value + TensorShapeVector axes; + InlinedHashSet axes_set; + + // Note: Starting from opset 13, "axes" is provided as a second input to the Squeeze operator. + // In opset 11 and earlier, "axes" is defined as a node attribute instead. + if (node_.InputDefs().size() > 1) { + const auto* input_1 = node_.InputDefs()[1]; + ORT_TRY { + ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), axes)); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + LOGS(logger_, ERROR) << ex.what(); + LOGS(logger_, INFO) << "Skip Unsqueeze op custom data propagation."; + }); + return Status::OK(); + } + } else { + const auto& attrs = node_.GetAttributes(); + auto it = attrs.find("axes"); + if (it != attrs.end()) { + const auto& axes_attr = it->second; + for (const auto& i : axes_attr.ints()) { + axes.push_back(i); + } + } + } + + // axes is required, if not provided just do nothing and return. + if (axes.empty()) { + return Status::OK(); + } + ORT_TRY { + for (size_t i = 0; i < axes.size(); ++i) { + // Negative value means counting dimensions from the back. 
+ axes_set.insert(HandleNegativeAxis(axes[i], tensor_shape_proto.dim_size())); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + LOGS(logger_, ERROR) << ex.what(); + LOGS(logger_, INFO) << "Skip Unsqueeze op custom data propagation."; + }); + return Status::OK(); + } + + auto& inferred_shape_values = output_def_.GetMutableInferredShapeValues(); + + if (!inferred_shape_values.has_value()) { + inferred_shape_values.emplace(); + } + inferred_shape_values->clear_dim(); + + int64_t axis = 0; + for (const auto& dim : tensor_shape_proto.dim()) { + if (axes_set.find(axis) != axes_set.end()) { + inferred_shape_values->add_dim()->set_dim_value(1); + } + + auto value = dim.dim_value(); + inferred_shape_values->add_dim()->set_dim_value(value); + + axis += 1; + } + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.h b/onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.h new file mode 100644 index 0000000000000..26b6587aa93fc --- /dev/null +++ b/onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "custom_data_propagation.h" +#include "core/graph/graph.h" + +namespace onnxruntime { +/** + * @brief Class to infer the output values/scalar for 'Unsqueeze' operator given the input is shape values. + * + */ +class UnsqueezeOpDataPropagation : public CustomDataPropagationBase { + public: + UnsqueezeOpDataPropagation(const Node& node, + NodeArg& output_def, + std::function func, + const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation, + const logging::Logger& logger) noexcept + : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {} + + Status infer() override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3d67314cf693a..15e7fad0d4a1a 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -16,13 +16,13 @@ #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" +#include "core/providers/common.h" #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/error_code_helper.h" #include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/tensor_external_data_info.h" #include "core/framework/tensor_shape.h" -#include "core/framework/tensor_type_and_shape.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/graph/function_utils.h" @@ -37,6 +37,7 @@ #include "core/graph/node_attr_utils.h" #include "core/graph/op.h" #include "core/graph/runtime_optimization_record_container.h" +#include "data_propagation/custom_data_propagation.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/function.h" @@ -44,6 +45,7 @@ #include "core/graph/schema_registry.h" #include "onnx/checker.h" #include "onnx/defs/parser.h" +#include "onnx/defs/tensor_proto_util.h" using namespace ONNX_NAMESPACE::checker; #endif @@ -1231,28 +1233,6 @@ Graph::Graph(const Model& owning_model, ArgNameToTypeMap name_to_type_map; const auto& model_path = ModelPath(); - // If the tensor proto data is large enough, move data from TensorProto to an OrtValue - // - Add external data reference to TensorProto that points to 
an OrtValue. - // This lambda should not be used on initializers that already have external data reference. - // Otherwise, this function does nothing. - auto put_large_tensor_in_ort_value = [this, &model_path](ONNX_NAMESPACE::TensorProto& tensor_proto) { - size_t size_in_bytes = 0; - ORT_THROW_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes)); - if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) { - OrtValue ort_value; - ORT_THROW_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, - CPUAllocator::DefaultInstance(), ort_value)); - constexpr const bool use_tensor_buffer_true = true; - auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(), - use_tensor_buffer_true); - assert(ort_value.IsAllocated()); - auto ins_result = ortvalue_initializers_.insert_or_assign(tensor_proto_to_add.name(), std::move(ort_value)); - ORT_ENFORCE(ins_result.second, "Unexpected duplicate insert or assign OrtValue for tensor: ", tensor_proto_to_add.name(), - " in the initializer list."); - tensor_proto = std::move(tensor_proto_to_add); - } - }; - // Process 'Constant' nodes // Put the 'TensorProto' stored in the 'Constant' nodes attribute into the graphs initializer list for (auto& node : graph_proto_->node()) { @@ -1272,8 +1252,6 @@ Graph::Graph(const Model& owning_model, } } - put_large_tensor_in_ort_value(*tensor); - // Ensure initializers are also graph inputs. if (ir_version_ < 4) { TypeProto t{utils::TypeProtoFromTensorProto(*tensor)}; @@ -1350,25 +1328,7 @@ Graph::Graph(const Model& owning_model, } // Copy initial tensors to a map. - for (int i = 0, lim = graph_proto_->initializer_size(); i < lim; ++i) { - auto& tensor = *graph_proto_->mutable_initializer(i); - // If data is on disk, it will be loaded either by optimizers - // or during session state finalization. - // If data is already in memory, do nothing. - if (!utils::HasExternalData(tensor)) { - // sparse_tensor_names_ contain references to strings to save memory - // in case we replace the tensor_proto, we want to make sure we remove - // the old reference first, and then add a new one. - const bool is_sparse = sparse_tensor_names_.count(tensor.name()); - if (is_sparse) { - sparse_tensor_names_.erase(tensor.name()); - } - put_large_tensor_in_ort_value(tensor); - if (is_sparse) { - sparse_tensor_names_.emplace(tensor.name()); - } - } - + for (auto& tensor : graph_proto_->initializer()) { auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor); if (!p.second) { LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name() @@ -2717,8 +2677,8 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { if (!def) return nullptr; - // only return data if it's for a constant initializer. checks for outer scope initializers - // if this is a subgraph and the name isn't found locally. + // Returns if it's a constant initializer. + // Checks for outer scope initializers if this is a subgraph and the name isn't found locally. 
const TensorProto* initializer = graph_.GetConstantInitializer(def->Name(), true); if (initializer != nullptr) { // Check if this is in-memory external data (data stored in OrtValue) @@ -2740,8 +2700,60 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { " has in-memory external data but cannot get OrtValue during shape inference"); } } + + return initializer; + } + + // The following code handles cases where a node stores the previously inferred output shape values in its NodeArg. + // + // For example, the Reshape operator, its shape input may come from a producer node such as a Shape operator, + // and the inferred output shape value is already stored as a TensorShapeProto in corresponding NodeArg. + // + // In such cases, the Reshape operator should convert this TensorShapeProto into a TensorProto. + // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, + // allowing the real dimension values to be correctly used. + + const auto& inferred_shape_values = def->GetInferredShapeValues(); + + // Converts the inferred shape values if any to a TensorProto and returns the TensorProto. + if (inferred_shape_values.has_value() && inferred_shape_values->dim_size() > 0) { + TensorProto tensor_proto; + tensor_proto.set_data_type(TensorProto_DataType_INT64); + tensor_proto.add_dims(inferred_shape_values->dim_size()); + bool all_values = true; + for (const auto& dim : inferred_shape_values->dim()) { + if (dim.has_dim_value()) { + tensor_proto.add_int64_data(dim.dim_value()); + } else { + all_values = false; + break; + } + } + + if (all_values) { + temp_tensor_protos_.push_back(std::make_unique(std::move(tensor_proto))); + return temp_tensor_protos_.back().get(); + } } - return initializer; + + const std::optional inferred_shape_scalar_value = def->GetInferredShapeScalarValue(); + + // Converts the inferred shape scalar value if any to a TensorProto and returns the TensorProto. + // + // Note: ONNX's getShapeInput() internally calls getInputData() to retrieve a TensorProto (if available) + // and then extracts shape/dimension values from it. As a result, the scalar value may not be + // properly handled and propagated in ONNX's shape inference. + // However, Graph::SaveShapeValuesFromDataPropagation() properly handles data propagation for + // some operators. + if (inferred_shape_scalar_value.has_value()) { + TensorProto tensor_proto; + tensor_proto.set_data_type(TensorProto_DataType_INT64); + tensor_proto.add_int64_data(inferred_shape_scalar_value.value()); + temp_tensor_protos_.push_back(std::make_unique(std::move(tensor_proto))); + return temp_tensor_protos_.back().get(); + } + + return nullptr; } // ORT does not implement partial data propagation yet so just return nullptr. @@ -2784,9 +2796,223 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { // These need to outlive the shape inference call, so we store them here // Inference is per node and the instance of this context is on the stack, // so this is safe. + // It can also be used to temporarily save the inferred shape values as a TensorProto. mutable InlinedVector> temp_tensor_protos_; }; +// An implementation of the DataPropagationContext interface optional by operator-specific +// shape inference for onnxruntime graphs. 
+// Please see the description and usage of ONNX's data propagation here: +// https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.h#L117-L127 +class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext { + public: + DataPropagationContextImpl(Node& node) noexcept : node_(node) { + node_output_types_.resize(node.OutputDefs().size()); + } + + const AttributeProto* getAttribute(const std::string& name) const override { + auto& attribute_value_map = node_.GetAttributes(); + auto iter = attribute_value_map.find(name); + if (iter == attribute_value_map.end()) { + return nullptr; + } + return &iter->second; + } + + size_t getNumInputs() const noexcept override { + return node_.InputDefs().size(); + } + + const TypeProto* getInputType(size_t index) const override { + if (index >= getNumInputs()) { + return nullptr; + } + + const TypeProto* type = nullptr; + auto p_node_arg = node_.InputDefs().at(index); + if ((nullptr != p_node_arg) && p_node_arg->Exists()) { + type = p_node_arg->TypeAsProto(); + } + + return type; + } + + size_t getNumOutputs() const noexcept override { + return node_output_types_.size(); + } + + const TypeProto* getOutputType(size_t index) const override { + if (index >= getNumOutputs()) { + return nullptr; + } + + return &node_output_types_[index]; + } + + const TensorShapeProto* getInputData(size_t index) override { + if (index >= getNumInputs()) { + return nullptr; + } + + auto def = node_.InputDefs()[index]; + if (!def) + return nullptr; + + // Get the previously inferred shape values that stored in NodeArg's inferred_shape_values_ if any. + // Note: getInputData() only supports input data (shape values) that is a tensor not a scalar, + // becase the returning TensorShapeProto can't store scalar value. Therefore, op's data propagation + // defined in ONNX Op schema does not support scalar output. + // However, Graph::SaveShapeValuesFromDataPropagation() does support output scalar value for + // some operators. + const auto& tensor_shape_proto = def->GetInferredShapeValues(); + if (tensor_shape_proto.has_value()) { + return &*tensor_shape_proto; + } + + return nullptr; + } + + void addOutputData(size_t index, TensorShapeProto&& tsp) override { + if (index >= node_output_types_.size()) return; + + TypeProto& type_proto = node_output_types_[index]; + *type_proto.mutable_tensor_type()->mutable_shape() = std::move(tsp); + } + + void RunInferencing() { + auto* schema = node_.Op(); + if (nullptr != schema) { + schema->GetDataPropagationFunction()(*this); + } + } + + const std::vector& InferredOutputTypes() const { return node_output_types_; } + + private: + Node& node_; + std::vector node_output_types_; +}; + +Status Graph::SaveShapeValuesFromDataPropagation(const Node& node, + NodeArg& output_def, + const TypeProto& onnx_inferred_type_after_data_propagation) const { + // Helper function to get the input value if it's a initializer. + auto get_initialized_input_values_func = [&](const std::string& input_name, TensorShapeVector& input_values) + -> Status { + const TensorProto* initializer = this->GetConstantInitializer(input_name, true); + + if (initializer) { + // Get shape from TensorProto as well as element counts. + // If shape has dimension size equals zero, it means it's a scalar and has only one element. 
+ auto tensor_shape = utils::GetTensorShapeFromTensorProto(*initializer); + size_t element_cnt = narrow(tensor_shape.Size()); + + // Check if this is in-memory external data (data stored in OrtValue) + if (utils::HasExternalDataInMemory(*initializer)) { + // Try to get the OrtValue for this initializer + OrtValue ort_value; + if (this->GetOrtValueInitializer(input_name, ort_value, true)) { + const Tensor& tensor = ort_value.Get(); + if (initializer->data_type() == TensorProto_DataType_INT32) { + auto data_span = tensor.DataAsSpan(); + ORT_ENFORCE(data_span.size() == element_cnt, + "The element counts from Tensor should be the same" + "from using utils::GetTensorShapeFromTensorProto()"); + + size_t index = 0; + input_values.resize(element_cnt); + for (const auto& v : data_span) { + input_values[index] = static_cast(v); + ++index; + } + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + const int64_t* src = tensor.Data(); + memcpy(input_values.data(), src, element_cnt * sizeof(int64_t)); + } + } else { + // If we can't get the OrtValue, it is a bug + ORT_THROW("Initializer ", input_name, + " has in-memory external data but cannot get OrtValue during shape inference"); + } + } + // Unpack tensor from raw data, external data (not in memory) or the type specific data field + else { + if (initializer->data_type() == TensorProto_DataType_INT32) { + InlinedVector tmp_values; + tmp_values.resize(element_cnt); + ORT_RETURN_IF_ERROR(utils::UnpackTensor(*initializer, + this->ModelPath(), + tmp_values.data(), + element_cnt)); + + input_values.resize(element_cnt); + for (size_t i = 0; i < element_cnt; ++i) { + input_values[i] = static_cast(tmp_values[i]); // copy values + } + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + input_values.resize(element_cnt); + ORT_RETURN_IF_ERROR(utils::UnpackTensor(*initializer, + this->ModelPath(), + input_values.data(), + element_cnt)); + } + } + } + + return Status::OK(); + }; + + // For certain operators (e.g., Size, Squeeze, Unsqueeze), ONNX's + // PartialDataPropagationFunction() does not always produce complete or accurate + // inferred shape values. + // + // In particular: + // - Scalar inputs and outputs are not handled correctly. + // - Some operators require additional logic that is not covered by the default function. + // + // Therefore, for these cases, we perform custom data propagation to ensure + // correct and complete inference. + auto dp = CreateCustomDataPropagation(node, output_def, + get_initialized_input_values_func, + onnx_inferred_type_after_data_propagation, + logger_); + if (dp) { + ORT_RETURN_IF_ERROR(dp->infer()); + return Status::OK(); + } + + // If no custom data propagation is defined for the operator, + // fall back to using the result of ONNX's PartialDataPropagationFunction(), if available. + + int dim_size = 0; + if (onnx_inferred_type_after_data_propagation.has_tensor_type() && + onnx_inferred_type_after_data_propagation.tensor_type().has_shape()) { + dim_size = onnx_inferred_type_after_data_propagation.tensor_type().shape().dim_size(); + } + + if (dim_size > 0) { + // Only handle that the inferred shape values (from ONNX operator's PartialDataPropagationFunction() ) has rank > 0 + // and all dimensions have concrete (non-symbolic) values. 
+ for (int i = 0; i < dim_size; ++i) { + if (!onnx_inferred_type_after_data_propagation.tensor_type().shape().dim(i).has_dim_value()) { + return Status::OK(); + } + } + + if (!output_def.inferred_shape_values_.has_value()) { + output_def.inferred_shape_values_.emplace(); + } + + output_def.inferred_shape_values_->clear_dim(); + for (int i = 0; i < dim_size; ++i) { + auto value = onnx_inferred_type_after_data_propagation.tensor_type().shape().dim(i).dim_value(); + output_def.inferred_shape_values_->add_dim()->set_dim_value(value); + } + } + + return Status::OK(); +} + Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, const std::vector& input_types, std::vector& output_types, @@ -2978,11 +3204,20 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // returned here. SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes); InferenceContextImpl context(node, func, *this, options); + DataPropagationContextImpl data_propagation_context(node); { auto status = Status::OK(); ORT_TRY { context.RunInferencing(); + + // Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient + // for complete shape inference. For example, the Shape operator only provides the + // output's rank (1-dimensional) but not its actual dimension values. + // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also + // be executed to obtain the concrete output shape values, allowing accurate propagation + // of shape information throughout the graph. + data_propagation_context.RunInferencing(); } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -2994,6 +3229,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso const auto& onnx_inferred_types(context.InferredOutputTypes()); + const auto& onnx_inferred_types_after_data_propagation(data_propagation_context.InferredOutputTypes()); + // Infer and verify node output arg type information. 
int i = -1; for (auto& output_def : node.MutableDefinitions().output_defs) { @@ -3010,6 +3247,12 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso auto op_formal_parameter = op.outputs().at(operand_index); const TypeProto& onnx_inferred_type = onnx_inferred_types[i]; + const TypeProto& onnx_inferred_type_after_data_propagation = onnx_inferred_types_after_data_propagation[i]; + + ORT_RETURN_IF_ERROR(SaveShapeValuesFromDataPropagation(node, + *output_def, + onnx_inferred_type_after_data_propagation)); + DataType existing_type = output_def->Type(); DataType inferred_type = nullptr; @@ -3322,6 +3565,26 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { } } + ORT_RETURN_IF_ERROR(CleanUpShapeValuesFromDataPropagation()); + + return Status::OK(); +} + +Status Graph::CleanUpShapeValuesFromDataPropagation() { + for (auto node_index : nodes_in_topological_order_) { + auto& node = *GetNode(node_index); + + for (auto node_arg : node.MutableInputDefs()) { + node_arg->inferred_shape_values_.reset(); + node_arg->inferred_scalar_value_.reset(); + } + + for (auto node_arg : node.MutableOutputDefs()) { + node_arg->inferred_shape_values_.reset(); + node_arg->inferred_scalar_value_.reset(); + } + } + return Status::OK(); } @@ -3457,6 +3720,38 @@ Status Graph::Resolve(const ResolveOptions& options) { return ForThisAndAllSubgraphs(all_subgraphs, finalize_func); } +Status Graph::ConvertInitializersIntoOrtValues() { + std::vector all_subgraphs; + FindAllSubgraphs(all_subgraphs); + + auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status { + // if we have any initializers that are not in memory, put them there. + const auto& model_path = graph.ModelPath(); + auto& graph_proto = *graph.graph_proto_; + for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) { + auto& tensor_proto = *graph_proto.mutable_initializer(i); + if (utils::HasExternalData(tensor_proto)) { + continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize + } + + size_t size_in_bytes = 0; + ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes)); + if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) { + OrtValue ort_value; + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, + CPUAllocator::DefaultInstance(), ort_value)); + constexpr const bool use_tensor_buffer_true = true; + auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(), + use_tensor_buffer_true); + ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value)); + } + } + return Status::OK(); + }; + + return ForThisAndAllSubgraphs(all_subgraphs, put_weights_maybe_in_memory_func); +} + void Graph::SetName(const std::string& name) { graph_proto_->set_name(name); } diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 9d98a15d8457a..248c6d74e6cbd 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -634,6 +634,7 @@ MlasGemm( { MlasGemmBatch(Shape, &DataParams, 1, ThreadPool); } + /** * @brief Parameters that define the shape of a dynamically quantized GEMM operation. * @@ -646,6 +647,7 @@ struct MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS { size_t N = 0; /**< Column size of matrix B */ size_t K = 0; /**< Column size of matrix A and Row size of matrix B */ }; + /** * @brief Parameters that define the data buffers and layout for a dynamic quant GEMM. 
* @@ -680,6 +682,14 @@ MlasDynamicQGemm ( MlasDynamicQGemmBatch(Shape, DataParams, 1, ThreadPool); } +/** + * @brief Determines whether a dynamic quantized GEMM implementation is available on the current platform. + * + * MlasDynamicQGemm() and MlasDynamicQGemmBatch() should only be called if this function returns true. + */ +bool +MLASCALL +MlasIsDynamicQGemmAvailable(); // // Symmetric QGEMM has limited buffer overrun. diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index a1c2e467188f7..186dc81d7b7b7 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -201,6 +201,17 @@ MlasGemmBatch( }); } +bool +MLASCALL +MlasIsDynamicQGemmAvailable() +{ +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + return ArmKleidiAI::UseSME2; +#else + return false; +#endif +} + void MLASCALL MlasDynamicQGemmBatch ( @@ -209,11 +220,11 @@ MlasDynamicQGemmBatch ( const size_t BatchN, MLAS_THREADPOOL* ThreadPool ) { + assert(MlasIsDynamicQGemmAvailable()); + #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) - //No fallback and putting in guards. This implementation is SME2 specific. - if(ArmKleidiAI::UseSME2){ - ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); - } + //No fallback + ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); #endif MLAS_UNREFERENCED_PARAMETER(Shape); @@ -332,13 +343,13 @@ MlasDynamicQgemmPackBSize( size_t K ) { + assert(MlasIsDynamicQGemmAvailable()); + size_t bytes = 0; #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) //No fallback available //TODO: Insert Override - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override - bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); - } + bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); #endif MLAS_UNREFERENCED_PARAMETER(N); @@ -405,11 +416,15 @@ Return Value: const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); - // If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for + // If this gemm B argument is used in a dynamically quantized gemm operation we can optimize for // this use case. Concat both packed representations for later decision. 
This allows for cases later - // where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization + // where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization // for better performance - return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K); + if (MlasIsDynamicQGemmAvailable()) { + return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K); + } else { + return AlignedBytesRequired; + } } void @@ -423,11 +438,11 @@ MlasDynamicQgemmPackB( void* PackedB ) { + assert(MlasIsDynamicQGemmAvailable()); + #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) //No fallback - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override - ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); - } + ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); #endif MLAS_UNREFERENCED_PARAMETER(N); diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc index 2bbc70db16cde..c37561c0086b0 100644 --- a/onnxruntime/core/optimizer/bias_softmax_fusion.cc +++ b/onnxruntime/core/optimizer/bias_softmax_fusion.cc @@ -44,7 +44,7 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s // check node is add and has single output if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) || - !graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider, kRocmExecutionProvider}) || + !graph_utils::IsSupportedProvider(node, {kCudaExecutionProvider}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { return false; } @@ -239,7 +239,7 @@ Status BiasSoftmaxFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve // only support GPU execution provider auto& cep = GetCompatibleExecutionProviders(); - if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end() && cep.find(kRocmExecutionProvider) == cep.end()) + if (cep.size() > 0 && cep.find(kCudaExecutionProvider) == cep.end()) return Status::OK(); for (auto node_index : node_topology_list) { diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 04f74eb860443..b7f5af5888be0 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -79,7 +79,7 @@ class ConvActivationSelector : public NodeSelector { return std::nullopt; } - auto is_supported_non_cuda_rocm_ep_activation = [&graph_viewer](const Node& activation_node) { + auto is_supported_non_cuda_ep_activation = [&graph_viewer](const Node& activation_node) { if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Relu", {6, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Sigmoid", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Tanh", {6, 13}) || @@ -105,17 +105,13 @@ class ConvActivationSelector : public NodeSelector { // check EP type and activation if (node_ep == kCudaExecutionProvider) { return std::nullopt; - } else if (node_ep == kRocmExecutionProvider) { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { - return std::nullopt; - } } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider || node_ep == kWebGpuExecutionProvider) { - if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && + if (!is_supported_non_cuda_ep_activation(*next_node) && 
!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { return std::nullopt; } } else { - if (!is_supported_non_cuda_rocm_ep_activation(*next_node)) { + if (!is_supported_non_cuda_ep_activation(*next_node)) { return std::nullopt; } } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 3680127ed4793..fdd4f5aa27862 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -76,7 +76,6 @@ #include "core/optimizer/quick_gelu_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" -#include "core/optimizer/rocm_blas_alt_impl.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" @@ -275,10 +274,6 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } - // add __backwardpass attribute to nodes after YieldOp, ROCm-only - const InlinedHashSet rocm_ep = {onnxruntime::kRocmExecutionProvider}; - transformers.emplace_back(std::make_unique(rocm_ep)); - // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); @@ -305,33 +300,26 @@ InlinedVector> GenerateTransformers( const InlinedHashSet cuda_eps = {onnxruntime::kCudaExecutionProvider}; - const InlinedHashSet cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; - const InlinedHashSet cpu_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_acl_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kRocmExecutionProvider, + const InlinedHashSet cpu_cuda_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider}; + const InlinedHashSet cpu_cuda_dml_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_acl_cuda_dml_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider, + onnxruntime::kWebGpuExecutionProvider}; + const InlinedHashSet cpu_cuda_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider, onnxruntime::kJsExecutionProvider, onnxruntime::kWebGpuExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, - 
onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider, - onnxruntime::kJsExecutionProvider, - onnxruntime::kWebGpuExecutionProvider}; const InlinedHashSet cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider, onnxruntime::kAclExecutionProvider}; @@ -362,30 +350,30 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_dml_acl_eps)); transformers.emplace_back(std::make_unique(cpu_acl_eps)); - transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_webgpu_eps)); - - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_armnn_js_webgpu_eps)); + + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps, level)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps, level)); + transformers.emplace_back(std::make_unique(cpu_cuda_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); transformers.emplace_back(std::make_unique(cuda_eps)); // Run MatMulAddFusion again after *AttentionFusion transforms with `preserve_attention_pattern = false`, // to cleanup the remaining MatMul-Add that were part of the attention pattern but not detected or fused. transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, false)); - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); // GeluApproximation has side effects which may change results. It needs to be manually enabled, // or alternatively the model can be updated offline using a model conversion script // e.g. 
fusion_gelu_approximation function used by onnxruntime/python/tools/transformers/onnx_model_bert.py if (enable_gelu_approximation) { - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_eps)); } #ifdef ENABLE_TRITON @@ -396,15 +384,15 @@ InlinedVector> GenerateTransformers( } #endif // ENABLE_TRITON - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_eps)); + transformers.emplace_back(std::make_unique(cuda_eps)); #ifdef ENABLE_TRAINING - transformers.emplace_back(std::make_unique(cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cuda_eps)); + transformers.emplace_back(std::make_unique(cuda_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_eps)); #endif - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); transformers.emplace_back(std::make_unique(dml_ep)); #ifdef MLAS_TARGET_AMD64_IX86 diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 1e88ed44b1a8a..8a7f83e871768 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -633,9 +633,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // if there is a Cast between x and y. Having Cast between means cannot fuse. const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0); bool has_leading_cast = false; - bool is_gpu_ep = (pow_node.GetExecutionProviderType() == kCudaExecutionProvider || - pow_node.GetExecutionProviderType() == kRocmExecutionProvider) || - skip_device_check_; + bool is_gpu_ep = pow_node.GetExecutionProviderType() == kCudaExecutionProvider || skip_device_check_; if (is_gpu_ep && p_pow_input_node) { Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index()); // If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div) diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 7ceb61b4aabc5..cc222e5e342dc 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -198,7 +198,6 @@ bool IsMatMulInputTypeSupported(const Node& node) { // if no matching key is present, any data type is allowed static const InlinedHashMap> k_supported_data_types{ {kCudaExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}}, - {kRocmExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}}, {kCpuExecutionProvider, {"tensor(float)"}}, }; diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc deleted file mode 100644 index decb25f565efe..0000000000000 --- a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
-#include - -#include "core/optimizer/initializer.h" -#include "core/optimizer/rocm_blas_alt_impl.h" -#include "core/graph/graph_utils.h" - -using namespace ONNX_NAMESPACE; -using namespace ::onnxruntime::common; -namespace onnxruntime { - -Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - bool is_backward_pass = false; - - for (auto node_index : node_topology_list) { - auto& node = *graph.GetNode(node_index); - - if (node.OpType() == "YieldOp") { - is_backward_pass = true; - } - - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - - if (is_backward_pass) { - node.AddAttribute(std::string("__backwardpass"), static_cast(1)); - modified = true; - } - } - - return Status::OK(); -} -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.h b/onnxruntime/core/optimizer/rocm_blas_alt_impl.h deleted file mode 100644 index 11744d0dac32b..0000000000000 --- a/onnxruntime/core/optimizer/rocm_blas_alt_impl.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" -#include "core/graph/graph_utils.h" - -namespace onnxruntime { - -class RocmBlasAltImpl : public GraphTransformer { - public: - RocmBlasAltImpl(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("RocmBlasAltImpl", compatible_execution_providers) {} - - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 6cbbdd4e0a7ef..1eb03af3befa4 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -81,6 +81,10 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(captureState); } +void Telemetry::LogCompileModel(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); +} + void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { ORT_UNUSED_PARAMETER(session_id); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index b60345e1b8a80..9c2859f7634b6 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,6 +66,8 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const; + virtual void LogCompileModel(uint32_t session_id) const; + virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 2e5d334856278..693e265af46b1 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -334,6 +334,20 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio } } +void WindowsTelemetry::LogCompileModel(uint32_t session_id) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + 
"CompileModel", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId")); +} + void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { if (global_register_count_ == 0 || enabled_ == false) diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 261d14a7fed8c..044feec071223 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,6 +59,8 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const override; + void LogCompileModel(uint32_t session_id) const override; + void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 9e49f068c680c..15baf7309070d 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -38,7 +38,7 @@ struct ProviderHostCPU { virtual Status NonMaxSuppressionBase__PrepareCompute(OpKernelContext* ctx, PrepareContext& pc) = 0; virtual Status NonMaxSuppressionBase__GetThresholdsFromInputs(const PrepareContext& pc, int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold) = 0; -#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) // From cpu/tensor/size.h virtual Status Size__Compute(const Size* p, OpKernelContext* context) = 0; @@ -254,7 +254,7 @@ struct ProviderHostCPU { extern ProviderHostCPU& g_host_cpu; -#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) namespace GatherElements { inline Status ValidateInputShapes(const TensorShape& input_data_shape, const TensorShape& indices_shape, @@ -336,7 +336,7 @@ inline Status ExecuteTritonOpByFuncName(OpKernelContext* p_ctx, const std::strin } // namespace contrib #endif // ENABLE_TRITON -#endif // USE_CUDA || USE_CUDA_PROVIDER_INTERFACE || USE_ROCM +#endif // USE_CUDA || USE_CUDA_PROVIDER_INTERFACE #endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/matmul_helper.h b/onnxruntime/core/providers/cpu/math/matmul_helper.h index d7275ee324756..9da7509eea2c6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_helper.h +++ b/onnxruntime/core/providers/cpu/math/matmul_helper.h @@ -23,7 +23,7 @@ inline void TensorShapeCopyDims(const TensorShape& shape, int64_t* dims, size_t class MatMulComputeHelper { public: // fill_offsets is to control if to fill offsets here. - // For CUDA/ROCM kernel when we can use GemmStridedBatched, we don't need to fill the offsets. + // For CUDA kernel when we can use GemmStridedBatched, we don't need to fill the offsets. 
Status Compute(const TensorShape& orig_left_shape, const TensorShape& orig_right_shape, bool transa = false, bool transb = false, bool trans_batch_a = false, bool trans_batch_b = false, diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h index 5cfd1ecee602a..e20e9ce0c81c2 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h @@ -10,11 +10,6 @@ #define ORT_DEVICE __device__ #define HelperMin(a, b) _Min(a, b) #define HelperMax(a, b) _Max(a, b) -#elif defined(__HIPCC__) -#include "core/providers/rocm/cu_inc/common.cuh" -#define ORT_DEVICE __host__ __device__ -#define HelperMin(a, b) _Min(a, b) -#define HelperMax(a, b) _Max(a, b) #else #include #define ORT_DEVICE @@ -50,8 +45,6 @@ struct SelectedIndex { #ifdef __NVCC__ namespace cuda { -#elif defined(__HIPCC__) -namespace rocm { #endif namespace nms_helpers { @@ -151,7 +144,5 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b } // namespace nms_helpers #ifdef __NVCC__ } // namespace cuda -#elif defined(__HIPCC__) -} // namespace rocm #endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh b/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh index 1469f55f0bfda..9d84920f76df9 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh @@ -194,13 +194,8 @@ void BinaryElementWiseNoBroadcastImpl( if (count == 0) // special case where there's a dim value of 0 in the output shape return; -#ifdef USE_ROCM - const int num_elements_per_thread = 2; - const int num_threads_per_block = 512; -#else const int num_elements_per_thread = GridDim::maxElementsPerThread; const int num_threads_per_block = GridDim::maxThreadsPerBlock; -#endif int blocksPerGrid = static_cast(CeilDiv(count, num_threads_per_block * num_elements_per_thread)); #define FUNC_CALL(NumElemT) \ @@ -237,13 +232,8 @@ void _BinaryElementWiseImpl( if (count == 0) // special case where there's a dim value of 0 in the output shape return; -#ifdef USE_ROCM - const int num_elements_per_thread = 2; - const int num_threads_per_block = 512; -#else const int num_elements_per_thread = GridDim::maxElementsPerThread; const int num_threads_per_block = GridDim::maxThreadsPerBlock; -#endif int blocksPerGrid = static_cast(CeilDiv(count, num_threads_per_block * num_elements_per_thread)); NumElemT N = static_cast(count); diff --git a/onnxruntime/core/providers/cuda/cu_inc/elementwise_impl.cuh b/onnxruntime/core/providers/cuda/cu_inc/elementwise_impl.cuh index 07a65bd252304..dc5b62dfbedf2 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/elementwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/elementwise_impl.cuh @@ -8,13 +8,8 @@ namespace onnxruntime { namespace cuda { -#ifdef USE_ROCM -constexpr int kElementsPerThread = 2; -constexpr int kThreadsPerBlock = 512; -#else constexpr int kElementsPerThread = GridDim::maxElementsPerThread; constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock; -#endif template __global__ void ElementwiseKernel(T* output_data, const FuncT functor, TIndex N) { diff --git a/onnxruntime/core/providers/cuda/cuda_provider_interface.cc b/onnxruntime/core/providers/cuda/cuda_provider_interface.cc index 6cd5368cb7341..9632ecba3d951 100644 --- 
a/onnxruntime/core/providers/cuda/cuda_provider_interface.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_interface.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/session/onnxruntime_c_api.h" -#if !defined(USE_ROCM) namespace onnxruntime { struct Provider; @@ -16,5 +15,3 @@ ORT_API(onnxruntime::Provider*, GetProvider) { return reinterpret_cast(onnxruntime::GetProvider()); } } - -#endif diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index fbee1841ae8d5..091f9af0a593e 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -250,7 +250,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCudaNotificationOnDevice); // wait cuda notification on cpu ep stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCudaNotificationOnHost); - if (!use_existing_stream) + if (!use_existing_stream) { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream, ep_info](const OrtDevice& device) { CUDA_CALL_THROW(cudaSetDevice(device.Id())); cudaStream_t stream = nullptr; @@ -258,7 +258,8 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis // CUDA_CALL_THROW(cudaStreamCreate(&stream)); return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr, ep_info); }); - else + stream_handle_registry.RegisterSetDeviceFn(device_type, [](OrtDevice::DeviceId id) { CUDA_CALL_THROW(cudaSetDevice(id)); }); + } else { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream, external_stream, @@ -267,7 +268,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis ep_info](const OrtDevice& device) { return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle, ep_info); }); - stream_handle_registry.RegisterSetDeviceFn(device_type, [](OrtDevice::DeviceId id) { CUDA_CALL_THROW(cudaSetDevice(id)); }); + } } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cuh b/onnxruntime/core/providers/cuda/math/topk_impl.cuh index 0cbc848be971c..9c9ed73079701 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cuh @@ -485,7 +485,7 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, int64_t N, \ int64_t dimension) -// This file is causing excessive long compilation time in ROCm EP. Split all those compilations into multiple +// This file is causing excessive long compilation time. Split all those compilations into multiple // translation units to speed it up. 
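// Illustrative note on the comment above (assumed file layout, not introduced by
// this patch): each thin translation unit defines TOPK_IMPL_TYPE before including
// this header, e.g.
//
//   // topk_impl_f32.cu -- hypothetical file name used for illustration
//   #define TOPK_IMPL_TYPE float
//   #include "topk_impl.cuh"
//
// so the per-type TOPKIMPLE instantiations below build as separate objects in
// parallel rather than in one oversized translation unit.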
TOPKIMPLE(TOPK_IMPL_TYPE); diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 90b542beaaf26..887d11a49db46 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -436,10 +436,6 @@ void HostApplyLayerNorm( parallel_rows *= 2; } dim3 threads(warp_size, threads_y, 1); -#ifdef __HIP_PLATFORM_HCC__ - // Optimization for ROCm MI100 - threads.y = 1; -#endif const dim3 blocks(1, std::min(n1, maxGridY), 1); int nshared = threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index 5238c42387eb2..ec87bb86acdb4 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -53,11 +53,7 @@ void Fill(cudaStream_t stream, T* output, T value, int64_t count); */ template struct TArray { -#if defined(USE_ROCM) -#define TARRAY_CONSTRUCTOR_SPECIFIERS __host__ __device__ -#else #define TARRAY_CONSTRUCTOR_SPECIFIERS -#endif TARRAY_CONSTRUCTOR_SPECIFIERS TArray() = default; TARRAY_CONSTRUCTOR_SPECIFIERS TArray(const TArray&) = default; diff --git a/onnxruntime/core/providers/cuda/tensor/concat_impl.cu b/onnxruntime/core/providers/cuda/tensor/concat_impl.cu index 84e1e76fae8de..f369ead0f90c4 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/concat_impl.cu @@ -10,13 +10,8 @@ namespace onnxruntime { namespace cuda { namespace { -#ifdef USE_ROCM -constexpr int kNumElementsPerThread = 2; -constexpr int kNumThreadsPerBlock = 512; -#else constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread; constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock; -#endif } // namespace // concat dimension are same for all inputs diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 81acb81be5025..2c4d5e6403dd5 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -14,11 +14,7 @@ namespace onnxruntime { namespace cuda { namespace { -#ifdef USE_ROCM -constexpr int kThreadsPerBlock = 256; -#else constexpr int kThreadsPerBlock = GPU_WARP_SIZE * 4; -#endif constexpr int kThreadWorkSize = 4; // General case to compute the input(for Gather)/output(for Scatter) and indices data offset given the thread ID diff --git a/onnxruntime/core/providers/cuda/tensor/slice_impl.cu b/onnxruntime/core/providers/cuda/tensor/slice_impl.cu index df392b45e9d5e..84021a99a8606 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/slice_impl.cu @@ -9,13 +9,8 @@ namespace onnxruntime { namespace cuda { namespace { -#ifdef USE_ROCM -constexpr int kNumElementsPerThread = 2; -constexpr int kNumThreadsPerBlock = 512; -#else constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread; constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock; -#endif } // namespace template diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index 52775b2e8be7a..ca82387600085 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -76,7 +76,6 @@ Status 
SplitKernel::ComputeInternal(OpKernelContext* ctx) const { auto input_dims = input_shape.GetDims(); auto output_dimensions{input_shape.AsShapeVector()}; -#ifndef USE_ROCM if (split_sizes.size() == 3 && ((axis + 1) == gsl::narrow_cast(input_shape.NumDimensions()))) { // we use (axis + 1) == num_dimensions to check if we are splitting on inner most axis. // only when split on inner axis and output size is 3, we can use Split3Inner. @@ -101,7 +100,6 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { output2->MutableDataRaw(), input_dims); } -#endif CudaAsyncBuffer output_ptr(this, num_outputs); gsl::span output_ptr_span = output_ptr.CpuSpan(); diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu index 6c2cdfe029a08..e8d26d5757bc0 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu @@ -10,13 +10,8 @@ namespace onnxruntime { namespace cuda { namespace { -#ifdef USE_ROCM -constexpr int kNumElementsPerThread = 2; -constexpr int kNumThreadsPerBlock = 512; -#else constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread; constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock; -#endif } // namespace template @@ -157,7 +152,6 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block return Status::OK(); } -#ifndef USE_ROCM template __global__ void _Split3InnerKernel(const int64_t size0_in_byte, const int64_t size1_in_byte, @@ -264,7 +258,6 @@ Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t return Status::OK(); } -#endif } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/tile_impl.cu b/onnxruntime/core/providers/cuda/tensor/tile_impl.cu index e3ef2965c5577..aaf54d276684a 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/tile_impl.cu @@ -7,13 +7,8 @@ namespace onnxruntime { namespace cuda { -#ifdef USE_ROCM -constexpr int num_elements_per_thread = 2; -constexpr int num_threads_per_block = 512; -#else constexpr int num_elements_per_thread = GridDim::maxElementsPerThread; constexpr int num_threads_per_block = GridDim::maxThreadsPerBlock; -#endif template __global__ void _UnRolledTileKernel(const size_t shape_rank, const TArray fdm_input_shape, diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index 9ecabcad504b3..69fbbf19241df 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -50,14 +50,6 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, -#endif - }, - { - kRocmExecutionProvider, -#ifdef USE_ROCM - true, -#else - false, #endif }, { diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index ef977161bcc37..26144e6ba3995 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -126,7 +126,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.4028234663852886e+38f, max, -3.4028234663852886e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) 
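// Note on the ClipV10 defaults above (illustration, not from the patch):
// 3.4028234663852886e+38f is the exact decimal spelling of
// std::numeric_limits<float>::max(), e.g.
//
//   #include <limits>
//   static_assert(3.4028234663852886e+38f == std::numeric_limits<float>::max());
//
// whereas the shorter 3.402823e+38f literal rounds to a float slightly below
// FLT_MAX, so the old defaults did not quite span the full float range.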
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index c9cd6e21b4eba..4787f6a80e959 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -5,8 +5,6 @@ #include "core/providers/migraphx/gpu_data_transfer.h" #include "core/providers/migraphx/migraphx_call.h" -// If you make change below, please also update onnxruntime/core/providers/rocm/gpu_data_transfer.cc - namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index 0baa8a1c67c67..f95a9f755a8bd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -12,7 +12,7 @@ namespace onnxruntime { enum MIGraphXResource { - hip_stream_t = rocm_resource_offset + hip_stream_t = migraphx_resource_offset }; struct MIGraphXNotification : synchronize::Notification { diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index e2a8005aba1da..2bca587bf3cb9 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -766,7 +766,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, NvExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) { // Only set device if user hasn't provided a compute stream - if (has_user_compute_stream) { + if (!has_user_compute_stream) { CUDA_CALL_THROW(cudaSetDevice(device_id)); (void)stream; } @@ -984,6 +984,17 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) stream_ = nullptr; // Will be created in compute function } + if (info.user_aux_stream_array != nullptr) { + if (info.auxiliary_streams <= 0) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Auxiliary streams must be greater than 0 when using external auxiliary streams")); + } + external_aux_streams_ = true; + aux_streams_ = reinterpret_cast(info.user_aux_stream_array); + } else { + external_aux_streams_ = false; + aux_streams_ = nullptr; + } + std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; // incase the EP context is dumped the engine cache has to be enabled @@ -1407,9 +1418,30 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map fused_inputs, fused_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. 
+ std::unordered_map fused_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', + // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -1428,7 +1460,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1443,7 +1475,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1464,39 +1496,33 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. 
+ fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. + graph_outputs_to_add.insert({output, output_order++}); + } } } @@ -1654,11 +1680,8 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t SetAllGraphInputs(graph_build); } - auto status = graph_build.Resolve(); - if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << status.ErrorMessage(); - ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX graph resolve failed: " + status.ErrorMessage())); - } + ORT_THROW_IF_ERROR(graph_build.Resolve()); + // Add parent graph output to the subgraph int i = 0; std::vector subgraph_outputs; @@ -1705,41 +1728,38 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t auto model = graph_viewer->CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - // ORT's default topological sort is using reversed DFS. - // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. - // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating - // the model proto that has different node ordering compared to original onnx model. 
- // save user provided external data in memory instead of writing to ModelProto // needed for models > 2GB std::vector userWeights; if (use_external_data_initializer_) { - auto c_api = Ort::GetApi(); - const InitializedTensorSet& allInitializers = graph_viewer->GetAllInitializedTensors(); + const auto& allInitializers = graph_viewer->GetAllInitializedTensors(); userWeights.reserve(allInitializers.size()); - for (auto& entry : allInitializers) { - OrtValue initializer_value; - auto* tp = entry.second; + for (const auto& [name, tp] : allInitializers) { if (utils::HasRawData(*tp)) { - userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); - } else if (graph_viewer->GetOrtValueInitializer(tp->name(), initializer_value)) { - // the initializer was marked as external data by the ORT graph at load time since it was provided in memory - size_t size = 0; - const void* ptr = nullptr; - Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); - Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); - userWeights.emplace_back(tp->name(), ptr, size); + // Keep inits in memory instead of writing to ModelProto. + userWeights.emplace_back(name, tp->raw_data().data(), tp->raw_data().size()); } else if (utils::HasExternalDataInMemory(*tp)) { - // only copy and take ownership of the data if none of the above conditions are met - std::unique_ptr full_init; - ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); - userWeights.emplace_back(std::move(full_init->name()), std::move(full_init->raw_data())); + // the initializer was marked as external data by the ORT graph at load time since it was provided in memory + if (OrtValue v; graph_viewer->GetOrtValueInitializer(name, v)) { + Ort::ConstValue initializer_value{&v}; + const size_t size = initializer_value.GetTensorSizeInBytes(); + const void* ptr = initializer_value.GetTensorRawData(); + userWeights.emplace_back(name, ptr, size); + } else { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(name, full_init->raw_data()); + } } } } + // ORT's default topological sort is using reversed DFS. + // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. + // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating + // the model proto that has different node ordering compared to original onnx model. 
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -2567,30 +2587,27 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // exclude weights if external std::vector userWeights; if (use_external_data_initializer_) { - auto c_api = Ort::GetApi(); const InitializedTensorSet& allInitializers = graph_body_viewer.GetAllInitializedTensors(); userWeights.reserve(allInitializers.size()); - for (auto& entry : allInitializers) { - OrtValue initializer_value; - auto* tp = entry.second; + for (const auto& [name, tp] : allInitializers) { if (utils::HasRawData(*tp)) { - userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); - } else if (graph_body_viewer.GetOrtValueInitializer(tp->name(), initializer_value)) { - // the initializer was marked as external data by the ORT graph at load time since it was provided in memory - size_t size = 0; - const void* ptr = nullptr; - Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); - Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); - userWeights.emplace_back(tp->name(), ptr, size); + userWeights.emplace_back(name, tp->raw_data().data(), tp->raw_data().size()); } else if (utils::HasExternalDataInMemory(*tp)) { - // only copy and take ownership of the data if none of the above conditions are met - std::unique_ptr full_init; - ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); - userWeights.emplace_back(TensorrtUserWeights(std::move(full_init->name()), std::move(full_init->raw_data()))); + // the initializer was marked as external data by the ORT graph at load time since it was provided in memory + if (OrtValue v; graph_body_viewer.GetOrtValueInitializer(name, v)) { + Ort::ConstValue initializer_value{&v}; + const size_t size = initializer_value.GetTensorSizeInBytes(); + const void* ptr = initializer_value.GetTensorRawData(); + userWeights.emplace_back(name, ptr, size); + } else { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(name, full_init->raw_data()); + } } } } - // ORT's default topological sort is using reversed DFS. // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. 
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating @@ -3033,6 +3050,11 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP select an optimization profile for the current context failed"); } + // Set auxiliary stream if provided by user + if (external_aux_streams_ && aux_streams_ != nullptr) { + trt_context->setAuxStreams(aux_streams_, (int32_t)auxiliary_streams_); + } + // Check before using trt_engine if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found."); @@ -3444,6 +3466,11 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra } } + // Set auxiliary stream if provided by user + if (external_aux_streams_ && aux_streams_ != nullptr) { + trt_context->setAuxStreams(aux_streams_, (int32_t)auxiliary_streams_); + } + // Start CUDA graph capture with the correct stream // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream // Get the graph annotation ID that was stored during OnRunStart diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index bb8f687db094f..5c6ca20d75ec6 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -349,6 +349,8 @@ class NvExecutionProvider : public IExecutionProvider { mutable NvExecutionProviderInfo info_; bool external_stream_ = false; cudaStream_t stream_ = nullptr; + bool external_aux_streams_ = false; + cudaStream_t* aux_streams_ = nullptr; int max_partition_iterations_ = 1000; size_t min_subgraph_size_ = 1; size_t max_workspace_size_ = 0; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index f25718114891b..74e16079a7cad 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -16,6 +16,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi const ConfigOptions& session_options) { NvExecutionProviderInfo info{}; void* user_compute_stream = nullptr; + void* user_aux_stream_array = nullptr; void* onnx_bytestream = nullptr; void* external_data_bytestream = nullptr; ORT_THROW_IF_ERROR( @@ -41,8 +42,17 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi user_compute_stream = reinterpret_cast(address); return Status::OK(); }) + .AddValueParser( + nv::provider_option_names::kUserAuxStreamArray, + [&user_aux_stream_array](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_aux_stream_array = reinterpret_cast(address); + return Status::OK(); + }) .AddAssignmentToReference(nv::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) .AddAssignmentToReference(nv::provider_option_names::kMaxSharedMemSize, info.max_shared_mem_size) + .AddAssignmentToReference(nv::provider_option_names::kLengthAuxStreamArray, info.auxiliary_streams) .AddAssignmentToReference(nv::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) .AddAssignmentToReference(nv::provider_option_names::kDetailedBuildLog, 
info.detailed_build_log) .AddAssignmentToReference(nv::provider_option_names::kProfilesMinShapes, info.profile_min_shapes) @@ -56,6 +66,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); + info.user_aux_stream_array = user_aux_stream_array; info.onnx_bytestream = onnx_bytestream; info.external_data_bytestream = external_data_bytestream; @@ -98,8 +109,10 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv {nv::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {nv::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, {nv::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, + {nv::provider_option_names::kUserAuxStreamArray, MakeStringWithClassicLocale(reinterpret_cast(info.user_aux_stream_array))}, {nv::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, {nv::provider_option_names::kMaxSharedMemSize, MakeStringWithClassicLocale(info.max_shared_mem_size)}, + {nv::provider_option_names::kLengthAuxStreamArray, MakeStringWithClassicLocale(info.auxiliary_streams)}, {nv::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, {nv::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, {nv::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index 372e8196f38c2..26f392ad446a3 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -21,6 +21,7 @@ struct NvExecutionProviderInfo { int device_id{0}; bool has_user_compute_stream{false}; void* user_compute_stream{nullptr}; + void* user_aux_stream_array{nullptr}; int max_partition_iterations{1000}; int min_subgraph_size{1}; size_t max_workspace_size{0}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index c3fbccef84883..e5015e705958d 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -2,20 +2,23 @@ // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/shared_library/provider_api.h" -#include "nv_provider_factory.h" +#include #include -#include "nv_execution_provider.h" -#include "nv_provider_factory_creator.h" -#include "nv_data_transfer.h" -#include "nv_allocator.h" + +#include "core/providers/shared_library/provider_api.h" #include "core/framework/provider_options.h" +#include "core/framework/plugin_ep_stream.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" #include "core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h" -#include #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_stream_handle.h" +#include "nv_provider_factory.h" +#include "nv_execution_provider.h" +#include "nv_provider_factory_creator.h" +#include "nv_data_transfer.h" +#include "nv_allocator.h" + using namespace onnxruntime; namespace onnxruntime { @@ -199,6 +202,7 @@ struct NvTrtRtxOrtAllocator : OrtAllocator { NvTrtRtxOrtAllocator(const OrtMemoryInfo* mem_info, const OrtApi& api) : memory_info_{mem_info} { version = ORT_API_VERSION; Alloc = AllocImpl; + AllocOnStream = AllocOnStreamImpl; Free = FreeImpl; Info = InfoImpl; Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl @@ -223,6 +227,11 @@ struct NvTrtRtxOrtAllocator : OrtAllocator { return impl.allocator_->Alloc(size); } + static void* ORT_API_CALL AllocOnStreamImpl(struct OrtAllocator* this_, size_t size, OrtSyncStream* stream) { + auto& impl = *static_cast(this_); + return impl.allocator_->AllocOnStream(size, stream); + } + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { auto& impl = *static_cast(this_); impl.allocator_->Free(p); diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index b0d850ca04841..97f80478f6f8c 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -9,7 +9,7 @@ // The functions are typically implemented in // onnxruntime/core/providers//_provider_factory.cc. // -// For execution providers that are built as separate libraries (CUDA, TensorRT, ROCm, MIGraphX, DNNL, OpenVINO) +// For execution providers that are built as separate libraries (CUDA, TensorRT, MIGraphX, DNNL, OpenVINO) // the functions are implemented in provider_bridge_ort.cc. 
#include "core/providers/cpu/cpu_provider_factory_creator.h" @@ -62,10 +62,6 @@ #include "core/providers/rknpu/rknpu_provider_factory_creator.h" #endif -#if defined(USE_ROCM) -#include "core/providers/rocm/rocm_provider_factory_creator.h" -#endif - #if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) #include "core/providers/qnn/qnn_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 4d183b95bd938..0bb3accb4d754 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -76,6 +76,9 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes); } else if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { return CheckGpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); + } else if (IsIrBackend(qnn_model_wrapper.GetQnnBackendType())) { + // TODO: CheckIrDataTypes + return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Only support backend: CPU, HTP and GPU"); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 6d03e9cd6c622..2f261d95e5b29 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -110,7 +110,7 @@ Status ConvOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, } } - // Validate that weight is signed type for per-channel quantization (required by QNN docs). + // Validate quantization axis for per-channel quantized weights. 
bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); if (is_npu_backend) { const auto& input_1 = inputs[1]; // weight diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index f3d81d7d2fdd7..9f28e2609faa1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -574,6 +574,10 @@ bool QnnOpConfigWrapper::CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_inte return true; } +bool IsIrBackend(QnnBackendType backend_type) { + return backend_type == QnnBackendType::SERIALIZER; +} + bool IsNpuBackend(QnnBackendType backend_type) { return backend_type == QnnBackendType::HTP || backend_type == QnnBackendType::DSP; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 42f4d7bb60f34..77508f3934a20 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -96,6 +96,8 @@ enum class QnnBackendType : uint8_t { SERIALIZER, }; +bool IsIrBackend(QnnBackendType backend_type); + bool IsCpuBackend(QnnBackendType backend_type); bool IsNpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 85901ab6fdfec..8973a4efa8ba1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -222,14 +222,14 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name() + ". 
" + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 5be46cd480004..46f05ee40aa17 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -298,7 +298,6 @@ constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; -constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider"; constexpr const char* kNvTensorRTRTXExecutionProvider = "NvTensorRTRTXExecutionProvider"; constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; @@ -318,9 +317,6 @@ std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const c std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name); std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name); -std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name); -std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const char* name); - std::unique_ptr CreateGPUDataTransfer(); std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 0e5df0026d2c0..5732984af29b4 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -533,7 +533,7 @@ Status NonMaxSuppressionBase::GetThresholdsFromInputs(const PrepareContext& pc, Status GatherBase::PrepareForCompute(OpKernelContext* context, GatherBase::Prepare& p) const { return g_host_cpu.GatherBase__PrepareForCompute(this, context, reinterpret_cast(p)); } Status UnsqueezeBase::PrepareCompute(OpKernelContext* ctx, UnsqueezeBase::Prepare& p) const { return g_host_cpu.UnsqueezeBase__PrepareCompute(this, ctx, reinterpret_cast(p)); } -#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) bool TileOp::IsTileMemcpy(const TensorShape& input_shape, const int64_t* repeats, size_t rank, bool& is_batched_memcpy, size_t& num_of_elements_per_batch, size_t& num_of_copies_per_batch, size_t& num_of_batch_copies) { return g_host_cpu.TileOp__IsTileMemcpy(input_shape, repeats, rank, is_batched_memcpy, num_of_elements_per_batch, num_of_copies_per_batch, num_of_batch_copies); } diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f1d545d0c6b17..d05c1c285a5f3 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -70,13 +70,27 @@ struct IteratorHolder { bool operator!=(const IteratorHolder& p) const { return p_->operator!=(*p.p_); } void operator++() { p_->operator++(); } - const TResult& operator*() { return p_->operator*(); } + TResult& operator*() { return p_->operator*(); } T* operator->() { return p_.get(); } 
private: std::unique_ptr p_; }; +struct TensorProto_ConstIterator { + virtual ~TensorProto_ConstIterator() = default; + virtual bool operator!=(const TensorProto_ConstIterator& p) const = 0; + virtual void operator++() = 0; + virtual const ONNX_NAMESPACE::TensorProto& operator*() const = 0; +}; + +struct TensorProto_Iterator { + virtual ~TensorProto_Iterator() = default; + virtual bool operator!=(const TensorProto_Iterator& p) const = 0; + virtual void operator++() = 0; + virtual ONNX_NAMESPACE::TensorProto& operator*() const = 0; +}; + struct NodeAttributes_Iterator { virtual ~NodeAttributes_Iterator() {} @@ -201,19 +215,6 @@ struct ProviderHost { virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; -#ifdef USE_ROCM - virtual std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const char* name) = 0; - - virtual void rocm__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) = 0; - virtual void rocm__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) = 0; - virtual void rocm__Impl_Cast(void* stream, const double* input_data, float* output_data, size_t count) = 0; - virtual void rocm__Impl_Cast(void* stream, const float* input_data, double* output_data, size_t count) = 0; - - virtual Status RocmCall_false(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0; - virtual void RocmCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0; -#endif - virtual std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, gsl::span tentative_nodes, @@ -439,7 +440,8 @@ struct ProviderHost { // GraphProto virtual std::unique_ptr GraphProto__construct() = 0; virtual void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) = 0; - virtual void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0; + virtual ONNX_NAMESPACE::GraphProto& GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0; + virtual ONNX_NAMESPACE::GraphProto& GraphProto__operator_move_assign(ONNX_NAMESPACE::GraphProto* p, ONNX_NAMESPACE::GraphProto&& v) = 0; virtual const ONNX_NAMESPACE::ValueInfoProto& GraphProto__input(const ONNX_NAMESPACE::GraphProto* p, int index) = 0; virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_input(ONNX_NAMESPACE::GraphProto* p) = 0; @@ -492,7 +494,8 @@ struct ProviderHost { // TensorProto virtual std::unique_ptr TensorProto__construct() = 0; virtual void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) = 0; - virtual void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0; + virtual ONNX_NAMESPACE::TensorProto& TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0; + virtual ONNX_NAMESPACE::TensorProto& TensorProto__operator_move_assign(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto&& v) = 0; virtual bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const 
::std::string& name) = 0; virtual const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) = 0; @@ -521,8 +524,12 @@ struct ProviderHost { // TensorProtos virtual ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) = 0; - virtual int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual int TensorProtos__size(const ONNX_NAMESPACE::TensorProtos* p) = 0; virtual ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) = 0; + virtual std::unique_ptr TensorProtos__begin(const ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual std::unique_ptr TensorProtos__end(const ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual std::unique_ptr TensorProtos__begin(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual std::unique_ptr TensorProtos__end(ONNX_NAMESPACE::TensorProtos* p) = 0; // TensorShapeProto_Dimension virtual int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index c7400c276f912..d3584d12df235 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -175,7 +175,8 @@ struct AttributeProto final { struct GraphProto final { static std::unique_ptr Create() { return g_host->GraphProto__construct(); } static void operator delete(void* p) { g_host->GraphProto__operator_delete(reinterpret_cast(p)); } - void operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); } + GraphProto& operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); } + GraphProto& operator=(GraphProto&& v) noexcept { return g_host->GraphProto__operator_move_assign(this, std::move(v)); } const ValueInfoProto& input(int index) const { return g_host->GraphProto__input(this, index); } ValueInfoProtos* mutable_input() { return g_host->GraphProto__mutable_input(this); } @@ -241,7 +242,10 @@ struct NodeProto final { struct TensorProto final { static std::unique_ptr Create() { return g_host->TensorProto__construct(); } static void operator delete(void* p) { g_host->TensorProto__operator_delete(reinterpret_cast(p)); } - void operator=(const TensorProto& v) { g_host->TensorProto__operator_assign(this, v); } + TensorProto& operator=(const TensorProto& v) { + return g_host->TensorProto__operator_assign(this, v); + } + TensorProto& operator=(TensorProto&& v) noexcept { return g_host->TensorProto__operator_move_assign(this, std::move(v)); } bool has_name() const { return g_host->TensorProto__has_name(this); } void set_name(const ::std::string& name) { return g_host->TensorProto__set_name(this, name); } @@ -283,8 +287,12 @@ struct TensorProto final { struct TensorProtos final { TensorProto* Add() { return g_host->TensorProtos__Add(this); } - int size() { return g_host->TensorProtos__size(this); } + int size() const { return g_host->TensorProtos__size(this); } TensorProto& at(int index) { return g_host->TensorProtos__at(this, index); } + IteratorHolder begin() const { return g_host->TensorProtos__begin(this); } + IteratorHolder end() const { return g_host->TensorProtos__end(this); } + IteratorHolder begin() { return g_host->TensorProtos__begin(this); } + IteratorHolder end() { return g_host->TensorProtos__end(this); } PROVIDER_DISALLOW_ALL(TensorProtos) }; @@ -935,9 +943,9 @@ struct NodeAttributes final { 
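The `TensorProtos__begin/end` additions above, together with `IteratorHolder::operator*` now returning `TResult&`, are what make range-based iteration over a `TensorProtos` collection work from a provider shared library. A minimal sketch of what that enables (hypothetical provider-side helper, assuming the shared-provider headers touched above are already included):

#include <iostream>

// Hedged sketch: iterate a TensorProtos collection through the new begin()/end() bridge.
// `tensors` is any ONNX_NAMESPACE::TensorProtos obtained from a GraphProto wrapper.
void DumpTensorNames(ONNX_NAMESPACE::TensorProtos& tensors) {
  std::cout << "count: " << tensors.size() << "\n";    // size() is now const-callable as well
  for (ONNX_NAMESPACE::TensorProto& tp : tensors) {    // drives TensorProtos__begin/end via IteratorHolder
    if (tp.has_name()) {
      std::cout << tp.name() << "\n";
    }
  }
}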
ONNX_NAMESPACE::AttributeProto& operator[](const std::string& string) { return g_host->NodeAttributes__operator_array(this, string); } const ONNX_NAMESPACE::AttributeProto& at(const std::string& string) const { return g_host->NodeAttributes__at(this, string); } - IteratorHolder> begin() const { return g_host->NodeAttributes__begin(this); } - IteratorHolder> end() const { return g_host->NodeAttributes__end(this); } - IteratorHolder> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); } + IteratorHolder> begin() const { return g_host->NodeAttributes__begin(this); } + IteratorHolder> end() const { return g_host->NodeAttributes__end(this); } + IteratorHolder> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); } void insert(const NodeAttributes& v) { return g_host->NodeAttributes__insert(this, v); } void emplace(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__emplace(this, k, v); } void emplace(const std::string& k, ONNX_NAMESPACE::AttributeProto&& v) { g_host->NodeAttributes__emplace(this, k, std::move(v)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index cd0c0e4bffdb5..ce7285e3b114a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2035,9 +2035,30 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map fused_inputs, fused_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. + std::unordered_map fused_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', + // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. 
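The block comment above pins down the contract for the four order maps: each NodeArg keeps the index at which it was first recorded, and the fused subgraph's final input/output lists are laid out by ascending index. A small self-contained sketch of that ordering step (simplified; the real code below merges `fused_outputs`, `fused_outputs_to_add` and `graph_outputs_to_add` before doing this):

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hedged sketch: turn an unordered {arg -> first-seen order index} map into the
// ordered name list used for the fused subgraph boundary. NodeArgT stands in
// for onnxruntime::NodeArg (anything with a Name() accessor).
template <typename NodeArgT>
std::vector<std::string> ToOrderedNames(const std::unordered_map<const NodeArgT*, int>& args) {
  std::vector<std::pair<const NodeArgT*, int>> sorted(args.begin(), args.end());
  std::sort(sorted.begin(), sorted.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });
  std::vector<std::string> names;
  names.reserve(sorted.size());
  for (const auto& entry : sorted) {
    names.push_back(entry.first->Name());
  }
  return names;
}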
int input_order = 0; int output_order = 0; @@ -2056,7 +2077,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2071,7 +2092,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2092,39 +2113,33 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. + fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. 
+ graph_outputs_to_add.insert({output, output_order++}); + } } } @@ -2280,7 +2295,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect SetAllGraphInputs(graph_build); } - ORT_ENFORCE(graph_build.Resolve().IsOK()); + ORT_THROW_IF_ERROR(graph_build.Resolve()); // Add parent graph output to the subgraph int i = 0; @@ -2295,7 +2310,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto& graph_build_outputs = graph_build.GetOutputs(); subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); graph_build.SetOutputs(graph_build_outputs); - ORT_ENFORCE(graph_build.Resolve().IsOK()); + ORT_THROW_IF_ERROR(graph_build.Resolve()); // Check if input tensors have shapes if (iterations > 1) { @@ -2332,27 +2347,25 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. - // Save Initializer Data. - std::vector userWeights; + auto graph_proto = ONNX_NAMESPACE::GraphProto::Create(); + graph_viewer->ToProto(*graph_proto, true, true, 1 /*priority-based topological sort*/, !load_user_initializer_ /*include_initializer_data*/); - // Keep inits in memory instead of writing to ModelProto. + // Save Initializer Data. + std::vector userWeights; if (load_user_initializer_) { - auto allInitializers = graph_viewer->GetAllInitializedTensors(); - - for (auto& entry : allInitializers) { - auto* tp = entry.second; + const auto& allInitializers = graph_viewer->GetAllInitializedTensors(); + for (const auto& [name, tp] : allInitializers) { if (tp->has_raw_data()) { - userWeights.emplace_back(tp->name(), tp->raw_data()); + userWeights.emplace_back(name, tp->raw_data()); } else if (utils::HasExternalDataInMemory(*tp)) { std::unique_ptr full_init; ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); - userWeights.emplace_back(full_init->name(), full_init->raw_data()); + userWeights.emplace_back(name, full_init->raw_data()); } } } - - graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !load_user_initializer_ /*include_initializer_data*/); + *model_proto->mutable_graph() = std::move(*graph_proto); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -3098,22 +3111,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); + // Note, wrapping std::vector into a smart ptr is redundant as the vector is a smart ptr in a sense. 
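The note above is a design remark rather than a behavior change: `std::vector` already owns its heap buffer and moves in O(1), so wrapping it in a `std::unique_ptr` only adds an extra allocation and indirection. A hedged, self-contained illustration of the plain-value alternative the note alludes to (`Weights` and `BuilderState` are stand-ins, not types from this change):

#include <string>
#include <utility>
#include <vector>

struct Weights {               // stand-in for TensorrtUserWeights (hypothetical)
  std::string name;
  std::string data;
};

struct BuilderState {          // stand-in for whatever ends up owning the weights (hypothetical)
  std::vector<Weights> user_weights;
};

void CollectWeights(BuilderState& state) {
  std::vector<Weights> weights;             // plain value: the vector already owns its storage
  weights.push_back({"w0", "raw-bytes"});
  state.user_weights = std::move(weights);  // O(1) move; no separate heap object for the vector itself
}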
auto userWeights = std::make_unique>(); - if (load_user_initializer_) { - auto allInitializers = graph_body_viewer.GetAllInitializedTensors(); - - for (auto& entry : allInitializers) { - auto name = entry.first; - auto* tp = entry.second; - if (tp->has_raw_data()) { - userWeights->emplace_back( - TensorrtUserWeights(tp->name(), tp->raw_data())); + const auto& allInitializers = graph_body_viewer.GetAllInitializedTensors(); + for (const auto& [name, tp] : allInitializers) { + if (utils::HasRawData(*tp)) { + userWeights->emplace_back(name, tp->raw_data()); } else if (utils::HasExternalDataInMemory(*tp)) { std::unique_ptr full_init; ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); - userWeights->emplace_back( - TensorrtUserWeights(full_init->name(), full_init->raw_data())); + userWeights->emplace_back(name, full_init->raw_data()); } } } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 5866dd3e83624..4905df2a71867 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -611,3 +611,53 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return &the_global_api; } } + +struct ExternalEpLibaray { + ExternalEpLibaray(const std::string& libray_name) : libray_name_{libray_name} { + Ensure(); + } + onnxruntime::Provider* (*get_provider_api)(); + void (*create_ep_factories)(void*, const OrtApiBase*, void*, OrtEpFactory**, size_t, size_t*); + void (*set_session_option)(OrtSessionOptions*); + + void Ensure() { + if (handle_) + return; + auto& env = Provider_GetHost()->Env__Default(); + auto library_filename = PathString(LIBRARY_PREFIX) + PathString(libray_name_.begin(), libray_name_.end()) + LIBRARY_EXTENSION; + auto full_path = env.GetRuntimePath() + library_filename; + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "GetProvider", (void**)&get_provider_api)); + } + + void Clear() { + if (handle_) { + auto& env = Provider_GetHost()->Env__Default(); + auto status = env.UnloadDynamicLibrary(handle_); + vai_assert(status.IsOK(), status.ErrorMessage()); + handle_ = nullptr; + } + } + + private: + std::string libray_name_; + void* handle_{}; +}; +static std::unordered_map> g_external_ep_libaries; + +std::unique_ptr +CreateExecutionProviderFromAnotherEp(const std::string& lib, const OrtSessionOptions& session_options, + std::unordered_map& provider_options) { + auto it = g_external_ep_libaries.find(lib); + if (it == g_external_ep_libaries.end()) { + it = g_external_ep_libaries.emplace(lib, std::make_unique(lib)).first; + } + auto ep_lib = it->second.get(); + auto get_provider_func = ep_lib->get_provider_api; + auto provider = get_provider_func(); + std::unique_ptr ret; + provider->Initialize(); + std::ignore = provider->CreateIExecutionProvider(nullptr, nullptr, 0, const_cast(provider_options), session_options, *((OrtLogger*)nullptr), ret); + + return ret; +} \ No newline at end of file diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 7791ea430054a..567f2cb4b39e3 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -6,10 +6,12 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/provider_options.h" +#include 
"core/framework/execution_provider.h" #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" #include +#include void initialize_vitisai_ep(); void deinitialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); @@ -40,3 +42,6 @@ using EventInfo = std::tuple< void profiler_collect( std::vector& api_events, std::vector& kernel_events); +std::unique_ptr +CreateExecutionProviderFromAnotherEp(const std::string& lib, const OrtSessionOptions& session_options, + std::unordered_map& provider_options); diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 50f924e468ed0..e1a3ca43e162e 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -7,7 +7,6 @@ #include #include #include - #include "vaip/global_api.h" #include "./vitisai_execution_provider.h" #include "core/framework/execution_provider.h" @@ -57,6 +56,10 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider(const } } + auto it = provider_options.find("external_ep_libray"); + if (it != provider_options.end()) { + return CreateExecutionProviderFromAnotherEp(it->second, session_options, provider_options); + } auto ep_instance = std::make_unique(provider_options); ep_instance->SetLogger(reinterpret_cast(&session_logger)); return ep_instance; diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc index 85096d0e262d7..9948069c6779b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -78,8 +78,8 @@ bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, LOGS_DEFAULT(INFO) << "Creating Clip Op."; if (node_unit.SinceVersion() <= 6) { NodeAttrHelper helper(node_unit.GetNode()); - auto min = helper.Get("min", -3.402e+38f); - auto max = helper.Get("max", 3.402e+38f); + auto min = helper.Get("min", -3.4028234663852886e+38f); + auto max = helper.Get("max", 3.4028234663852886e+38f); auto op = graph_ep->GetGraph()->CreateOperation(min, max); (*op).BindInputs(inputs).BindOutputs(outputs); graph_ep->GetOps().push_back(std::move(op)); diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index b3eb4b5061423..3e1b87821fe2f 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? 
OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), + WebGpuDevice, OrtMemTypeDefault)), buffer_manager_{buffer_manager}, mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 7c38b4557e078..74b3d669fcf3b 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -11,6 +11,11 @@ namespace webgpu { class BufferManager; +inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, + OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::NONE, + 0}; + class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index ebe71c6ccfacd..d1a2011c8e191 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -6,22 +6,25 @@ namespace onnxruntime { namespace webgpu { -ComputeContext::ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context) + +ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel) : webgpu_context_{webgpu_context}, - kernel_context_{kernel_context}, - op_kernel_{op_kernel}, - ep_{ep} { + ep_{ep}, + op_kernel_{op_kernel} { } -const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) { +const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) { return context.ep_.BufferManager(); } -const SplitKConfig& ComputeContext::GetSplitKConfig() { - return webgpu_context_.GetSplitKConfig(); +ComputeContext::ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context) + : ComputeContextBase(webgpu_context, ep, op_kernel), + kernel_context_{kernel_context} { } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index ed16f2f0a1345..fdf89854469d6 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -24,7 +24,13 @@ namespace webgpu { class WebGpuContext; class BufferManager; -class ComputeContext final { +// +// Class ComputeContextBase is designed to provide basic context information +// for running a compute shader program. +// +// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created. 
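The comment above is the crux of this refactor: a `ComputeContextBase` can exist before any execution frame does, which is what later allows pre-packing code to run WebGPU programs, while per-inference kernels keep the full `ComputeContext`. A hedged sketch of the resulting method shapes (hypothetical kernel; the signatures mirror the ones used further down in this change, e.g. `Conv::PrePackInternal`):

// Hedged sketch, not a kernel from this change.
class MyWebGpuKernel final : public WebGpuKernel {
 public:
  // Called before graph execution: no OpKernelContext / execution frame yet, so
  // only ComputeContextBase facilities (AdapterInfo, RunProgram, Logger, ...) are usable.
  Status PrePackInternal(ComputeContextBase& context, const Tensor& tensor, int input_idx,
                         AllocatorPtr alloc, /*out*/ bool& is_packed) override;

  // Called per inference: ComputeContext adds KernelContext() for inputs/outputs.
  Status ComputeInternal(ComputeContext& context) const override;
};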
+// +class ComputeContextBase { public: // Nested accessor class to provide controlled access to BufferManager class BufferManagerAccessor { @@ -34,18 +40,31 @@ class ComputeContext final { friend class WebGpuContext; private: - static const webgpu::BufferManager& Get(const ComputeContext& context); + static const webgpu::BufferManager& Get(const ComputeContextBase& context); }; - ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context); + ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel); - ~ComputeContext() = default; + ~ComputeContextBase() = default; + + // + // Get the node name. + // + inline decltype(auto) NodeName() const { + return op_kernel_.Node().Name(); + } + + // + // Get the operator type. + // + inline decltype(auto) OpType() const { + return op_kernel_.Node().OpType(); + } // - // Get various information from the context. + // Get various information from the WebGPU context. // inline const wgpu::AdapterInfo& AdapterInfo() const { @@ -57,9 +76,6 @@ class ComputeContext final { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } - inline bool IsGraphCaptureEnabled() const { - return ep_.IsGraphCaptureEnabled(); - } #if !defined(__wasm__) inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return webgpu_context_.SubgroupMatrixConfigs(); @@ -67,17 +83,57 @@ class ComputeContext final { #endif // - // Get the kernel context. + // Get Split-K configuration. // - inline OpKernelContext& KernelContext() { - return kernel_context_; + inline const SplitKConfig& GetSplitKConfig() const { + return webgpu_context_.GetSplitKConfig(); + } + + // + // Get whether graph capture is enabled. + // + inline bool IsGraphCaptureEnabled() const { + return ep_.IsGraphCaptureEnabled(); } // // Get the logger. // inline const logging::Logger& Logger() const { - return kernel_context_.Logger(); + return *ep_.GetLogger(); + } + + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + protected: + WebGpuContext& webgpu_context_; + const WebGpuExecutionProvider& ep_; + const OpKernel& op_kernel_; +}; + +// +// Class ComputeContext provides all information a `ComputeContextBase` provides, and also +// access to `OpKernelContext` for input and output tensors. +// +class ComputeContext final : public ComputeContextBase { + public: + ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context); + + ~ComputeContext() = default; + + // + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; } // @@ -145,25 +201,8 @@ class ComputeContext final { return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst); } - // - // Run a compute shader program. - // - inline Status RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); - } - - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. 
- // - const SplitKConfig& GetSplitKConfig(); - private: - WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; - const OpKernel& op_kernel_; - const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 82645e30082e6..3c974ef5133c0 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -322,11 +322,14 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } - std::string use_sqrt_for_pow; + std::string use_pow_shortcut; if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + // use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0 // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 - use_sqrt_for_pow = - " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + use_pow_shortcut = + " else if (b == 2.0) {\n" + " return a * a;\n" + " } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" " return sqrt(a);\n" " }\n"; } @@ -337,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" - << use_sqrt_for_pow + << use_pow_shortcut << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 6aefa90a59285..b81977883dd70 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -30,7 +30,12 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const { } else { ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_)); } - MatMulWriteFnSource(shader, output, need_handle_bias_, true, c_components_, output_components_, c_is_scalar_); + + const ShaderVariableHelper* c = nullptr; + if (need_handle_bias_) { + c = &shader.AddInput("c", ShaderUsage::UseUniform); + } + MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_); return Status::OK(); } @@ -93,18 +98,21 @@ Status ApplyGemmPacked(const Tensor* a, } const uint32_t TILE_SIZE = 32; - const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; program.CacheHint(alpha, transA, transB, c_is_scalar) .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(num_tile_n, num_tile_m, 1) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, - {M}, /* dim_a_outer */ - {N}, /* 
dim_b_outer */ - {K}} /*dim_inner */ + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}, /*dim_inner */ + {dispatch_x}, /* logical_dispatch_x */ + {dispatch_y}, /* logical_dispatch_y */ + {1u}} /* logical_dispatch_z */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index dce5164693aa8..cb89ccefba313 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -32,7 +32,10 @@ class GemmProgram final : public Program { {"beta", ProgramUniformVariableDataType::Float32}, {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7cbc7f6a4a821..73242ed3ff1ba 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -15,26 +15,25 @@ namespace { void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, const ShaderVariableHelper& output, - bool has_bias, + const ShaderVariableHelper* c, int c_components, int output_components, bool c_is_scalar) { shader.AdditionalImplementation() << " let coords = vec2(u32(row), u32(colIn));\n"; - if (has_bias) { - const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform); + if (c != nullptr) { shader.AdditionalImplementation() << " value += output_element_t(uniforms.beta) * "; // We can be allowed to use broadcasting only when both components are equal. // There is only one case for c_components_ is not equal output_components. // I.g. the former is `1` and the latter is `4`. - // That means the shape of C is either {M,1} or {1,1} + // That means the shape of `c` is either {M,1} or {1,1} if (c_components == output_components) { shader.AdditionalImplementation() << "output_value_t(" - << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(u32(row), u32(colIn))", output)) << ");\n"; + << c->GetByOffset(c->BroadcastedIndicesToOffset("vec2(u32(row), u32(colIn))", output)) << ");\n"; } else if (c_is_scalar) { - shader.AdditionalImplementation() << "output_value_t(C[0]);\n"; + shader.AdditionalImplementation() << "output_value_t(" << c->GetByOffset("0") << ");\n"; } else { - shader.AdditionalImplementation() << "output_value_t(C[row]);\n"; + shader.AdditionalImplementation() << "output_value_t(" << c->GetByOffset("row") << ");\n"; } } shader.AdditionalImplementation() << output.SetByIndices("coords", "value") << "\n"; @@ -42,12 +41,12 @@ void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, void HandleMaybeBiasForMatMul(ShaderHelper& shader, const ShaderVariableHelper& output, - bool has_bias, + const ShaderVariableHelper* bias, std::string activation_snippet, bool is_channels_last) { shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; - if (has_bias) { - shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? 
"bias[colIn]" : "bias[row]") << ");\n"; + if (bias != nullptr) { + shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n"; } shader.AdditionalImplementation() << " " << activation_snippet << "\n" << output.SetByIndices("coords", "value") << "\n"; @@ -55,6 +54,7 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader, void HandleMatMulWithSplitK( ShaderHelper& shader, + const ShaderVariableHelper& output, ProgramVariableDataType output_variable_type) { shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; @@ -76,16 +76,8 @@ void HandleMatMulWithSplitK( let offset0 = i2o_output(coords) * 4u; for (var i = 0u; i < 4u; i++) { let offset = offset0 + i; - while (true) { - let old_output_i32 = atomicLoad(&output[offset]); - let old_output_f32 = bitcast(old_output_i32); - let new_output_f32 = old_output_f32 + value[i]; - let new_output_i32 = bitcast(new_output_f32); - let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); - if (output_compare_exchange.old_value == old_output_i32) { - break; - } - } +)"; + shader.AdditionalImplementation() << GenerateAtomicAddNonIntegerCode(output, "offset", "f32", "value[i]") << R"( } )"; break; @@ -98,16 +90,8 @@ void HandleMatMulWithSplitK( vec2h_values[1] = value.zw; for (var i = 0u; i < 2u; i++) { let offset = offset0 + i; - while (true) { - let old_output_i32 = atomicLoad(&output[offset]); - let old_output_vec2h = bitcast(old_output_i32); - let new_output_vec2h = old_output_vec2h + vec2h_values[i]; - let new_output_i32 = bitcast(new_output_vec2h); - let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); - if (output_compare_exchange.old_value == old_output_i32) { - break; - } - } +)"; + shader.AdditionalImplementation() << GenerateAtomicAddNonIntegerCode(output, "offset", "vec2h", "vec2h_values[i]") << R"( } )"; break; @@ -117,6 +101,20 @@ void HandleMatMulWithSplitK( } } +// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in +// `ProgramBase.SetDispatchGroupSize()` may be normalized in +// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use +// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. 
+void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { + shader.MainFunctionBody() + << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" + << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" + << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n" + << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -183,7 +181,7 @@ void MatMulReadFnSource(ShaderHelper& shader, void MatMulWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& output, - bool has_bias, + const ShaderVariableHelper* bias, bool is_gemm, int c_components, int output_components, @@ -203,16 +201,16 @@ void MatMulWriteFnSource(ShaderHelper& shader, if (use_split_k) { // Set output when MatMul is performed with Split-K. // When Split-K is used in MatMul, the bias will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram` - // instead of here, so `has_bias` and `is_channels_last` is not used for Split-K. Note that we - // still need to handle `has_bias` (and `is_channels_last` in the future) in + // instead of here, so `bias` and `is_channels_last` is not used for Split-K. Note that we + // still need to handle `bias` (and `is_channels_last` in the future) in // `MatMulFillBiasOrZeroBeforeSplitKProgram`. - ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled."); + ORT_ENFORCE(bias == nullptr, "Bias is not supported in MatMulProgram when Split-K is enabled."); ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled."); - HandleMatMulWithSplitK(shader, output_variable_type); + HandleMatMulWithSplitK(shader, output, output_variable_type); } else if (is_gemm) { - HandleMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); + HandleMaybeHaveBiasForGEMM(shader, output, bias, c_components, output_components, c_is_scalar); } else { - HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last); + HandleMaybeBiasForMatMul(shader, output, bias, activation_snippet, is_channels_last); } shader.AdditionalImplementation() @@ -274,20 +272,22 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const innerElementSize = " << inner_elements_size << ";\n" << "const tileInner = " << tile_inner << ";\n"; + InitializeLogicalWorkgroupIDAndGlobalID(shader); + shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" << " let tileRow = localRow * rowPerThread;\n" << " let tileCol = i32(local_id.x);\n" - << " let globalRow = i32(global_id.y) * rowPerThread;\n" - << " let globalCol = i32(global_id.x);\n" - << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << " let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << " let globalCol = i32(logical_global_id.x);\n" + << " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << " let globalColStart = i32(logical_workgroup_id.x) * " << 
tile_b_outer << ";\n" << " var acc: array, rowPerThread>;\n"; if (split_k) { // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from - // `kSplitK * i32(global_id.z)`. + // `kSplitK * i32(logical_global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. // Let kSplitK = 2, B = [d1, d2] @@ -305,15 +305,15 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) // In each workgroup: - // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z` // - When the computation in each workgroup is completed, add the result to Y with several // atomic built-in functions in `HandleMatMulWithSplitK()`. shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(global_id.z);\n" + << " var kStart = kSplitK * i32(logical_global_id.z);\n" - // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate // the index of split-k instead of batch. << " let batch = 0;\n" << " let batchIndices = 0u;\n"; @@ -321,7 +321,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" - << " let batch = i32(global_id.z);\n" + << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); } @@ -498,7 +498,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "const colPerThread = " << elements_per_thread_x << ";\n" << "const tileInner = " << tile_inner << ";\n"; - shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" + InitializeLogicalWorkgroupIDAndGlobalID(shader); + + shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? 
" let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -507,10 +509,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, shader.MainFunctionBody() << "let tileRow = i32(local_id.y) * rowPerThread;\n" << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(global_id.y) * rowPerThread;\n" - << "let globalCol = i32(global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << "let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << "let globalCol = i32(logical_global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index 7075debeb9952..e001544f9e50d 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -18,7 +18,7 @@ void MatMulReadFnSource(ShaderHelper& shader, void MatMulWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& output, - bool has_bias, + const ShaderVariableHelper* bias, bool is_gemm, int c_components, int output_components, diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 55c2c5773cc1f..72dd235eb820a 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -256,8 +256,6 @@ Status ComputeMatMul(ComputeContext* context, // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the // number of splits along `dim_inner`. - // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize - // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. 
split_dim_inner = split_k_config.GetSplitDimInner(); dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; @@ -271,7 +269,7 @@ Status ComputeMatMul(ComputeContext* context, .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 4daabe8246aa7..80a110c3b505c 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -26,14 +26,15 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const ShaderVariableHelper* bias = nullptr; if (has_bias_) { - shader.AddInput("bias", ShaderUsage::UseUniform); + bias = &shader.AddInput("bias", ShaderUsage::UseUniform); } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; // declare the read and write functions MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_); - MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type); + MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type); std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { @@ -54,8 +55,9 @@ bool MatMulProgram::NeedSplitK() const { Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const ShaderVariableHelper* bias = nullptr; if (has_bias_) { - shader.AddInput("bias", ShaderUsage::UseUniform); + bias = &shader.AddInput("bias", ShaderUsage::UseUniform); } // Handle bias with `MatMulWriteFnSource()`. @@ -63,7 +65,7 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& // `use_split_k` is true only when we do the actual MatMul with Split-K. // Currently we only support bias in vec4 and channels last format for Split-K MatMul. 
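Taken together with the earlier remark that bias is handled by `MatMulFillBiasOrZeroBeforeSplitKProgram` rather than by the Split-K MatMul itself, the implied flow is: seed the output with bias (or zero), then let every K-chunk atomically add its partial sum. A hedged numeric sketch of why that stays correct (plain floats stand in for the GPU atomics; this is not literal ORT code):

#include <cassert>
#include <vector>

int main() {
  const int K = 8, kSplitK = 4;
  std::vector<float> a(K, 1.0f), b(K, 2.0f);
  const float bias = 3.0f;

  float y = bias;                                   // pass 1: MatMulFillBiasOrZeroBeforeSplitKProgram
  for (int k_start = 0; k_start < K; k_start += kSplitK) {
    float partial = 0.0f;                           // one Split-K z-slice
    for (int k = k_start; k < k_start + kSplitK; ++k) {
      partial += a[k] * b[k];
    }
    y += partial;                                   // atomic add in the real shader
  }
  assert(y == bias + 16.0f);                        // 8 * (1 * 2) + bias
  return 0;
}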
MatMulWriteFnSource( - shader, output, has_bias_, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false, + shader, output, bias, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false, /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false); shader.MainFunctionBody() << R"( diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 143ba61c99e13..dbd193bc38f58 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -24,7 +24,10 @@ class MatMulProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); bool NeedSplitK() const; diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index 2f34aa21c8309..bf3bb53341418 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -64,7 +64,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { int components = input.NumComponents(); const std::string thread_max_decl = is_fp32_ - ? "var thread_max = x_value_t(-3.402823e+38f);\n" + ? "var thread_max = x_value_t(-3.4028234663852886e+38f);\n" : "var thread_max = x_value_t(-65504.0h);\n"; // Define shared memory for row max and row sum diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index e5fca91a53bf8..3285f1e6065bb 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -108,7 +108,7 @@ fn erf_v(v: x_value_t) -> x_value_t { constexpr const char HardSigmoidImpl[] = R"( fn hard_sigmoid_v(v: vec4) -> vec4 { let alpha = x_element_t(uniforms.attr[0]); - let beta_v = vec4(uniforms.attr[1]); + let beta_v = vec4(x_element_t(uniforms.attr[1])); return max(vec4(0.0), min(vec4(1.0), alpha * v + beta_v)); } diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 77fa46cb87518..5ff18235a706f 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/nn/conv.h" #include "core/providers/webgpu/nn/conv2d_mm.h" +#include "core/providers/webgpu/nn/im2col_matmul.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/tensor/transpose.h" @@ -99,10 +100,34 @@ Status Conv::ComputeInternal(ComputeContext& context modified_input_output_shapes.push_back(bias->Shape()); } modified_input_output_shapes.push_back(TensorShape(output_shape_vector)); + + const auto input_height = input_shape[is_channels_last ? 1 : 2]; + const auto input_width = input_shape[is_channels_last ? 
2 : 3]; + const auto input_channels = input_shape[is_channels_last ? 3 : 1]; + const auto kernel_height = kernel_shape[2]; + const auto kernel_width = kernel_shape[3]; + const auto output_height = output_shape_vector[is_channels_last ? 1 : 2]; + const auto output_width = output_shape_vector[is_channels_last ? 2 : 3]; + uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0; auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2; auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2; std::vector updated_pads{pad0, pad1}; + + if (CanApplyIm2ColMatMulProgram(context, + is_channels_last, + activation_.activation_kind_, + kernel_shape, + conv_attrs_.auto_pad, + onnxruntime::narrow(conv_attrs_.group))) { + return ApplyIm2ColMatMulProgram(context, + is_channels_last, + dilations, + pads, + strides, + output); + } + if (conv_attrs_.group > 1) { Tensor transposed_kernel; if (is_channels_last) { @@ -128,13 +153,6 @@ Status Conv::ComputeInternal(ComputeContext& context } return context.RunProgram(program); } - const auto input_height = input_shape[is_channels_last ? 1 : 2]; - const auto input_width = input_shape[is_channels_last ? 2 : 3]; - const auto input_channels = input_shape[is_channels_last ? 3 : 1]; - const auto kernel_height = kernel_shape[2]; - const auto kernel_width = kernel_shape[3]; - const auto output_height = output_shape_vector[is_channels_last ? 1 : 2]; - const auto output_width = output_shape_vector[is_channels_last ? 2 : 3]; const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0; if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) { @@ -216,6 +234,46 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(conv2d_mm_program); } +template +Status Conv::PrePackInternal(ComputeContextBase& /* context */, + const Tensor& tensor, + int input_idx, + AllocatorPtr /* alloc */, + /*out*/ bool& is_packed) { + is_packed = false; + + if constexpr (is_channels_last) { + if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) { + // only deal with 4D NHWC weights + + // TODO: implement weight transpose for pre-pack here + // Conv::ComputeInternal() should be updated to reflect the change: + // - if the initializer is packed, `context.Input(1)` will be nullptr. + // - in this case, use `transposed_kernel_` instead. 
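The TODO above spells out the consuming side of a future pre-pack: once the weight initializer is packed, `context.Input(1)` returns nullptr and `transposed_kernel_` takes its place. The commented-out steps below cover the PrePack side; a complementary hedged fragment for `Conv::ComputeInternal()` (hypothetical, not part of this change, which keeps the transpose at compute time):

// Hedged sketch of how ComputeInternal() would pick up the packed weight.
const Tensor* kernel = context.Input<Tensor>(1);
if (kernel == nullptr && transposed_kernel_ != nullptr) {
  // The initializer was consumed by PrePackInternal(); use the packed copy.
  kernel = transposed_kernel_.get();
}
ORT_RETURN_IF(kernel == nullptr, "Conv: weight tensor is missing.");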
+ + // // Step.1 - calculate transposed weight shape + // TensorShape transposed_kernel_shape{tensor.Shape()[2], + // tensor.Shape()[3], + // tensor.Shape()[1], + // tensor.Shape()[0]}; + + // // Step.2 - create transposed weight tensor + // transposed_kernel_ = std::make_unique(tensor.DataType(), transposed_kernel_shape, alloc); + + // // Step.3 - do transpose + // size_t perm[] = {2, 3, 1, 0}; + // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, + // perm, + // tensor, + // *transposed_kernel_)); + + // is_packed = true; // set this flag to true so that ORT will release the initializer tensor + } + } + + return Status::OK(); +} + // Explicit template instantiation for FusedConv template class Conv; template class Conv; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index cafaa272c0613..5bf94a459a44a 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -23,9 +23,16 @@ class Conv : public WebGpuKernel { } Status ComputeInternal(ComputeContext& context) const override; + Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed) override; + protected: ConvAttributes conv_attrs_; Activation activation_; + std::unique_ptr transposed_kernel_; // should only have value when `is_initializer` AND `is_4D` AND `is_NHWC` }; Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index 2d5424c52a3f2..c66f2cbd582d9 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -226,7 +226,10 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v {static_cast(dim_inner)}, {pads}, {strides}, - {dilations}}); + {dilations}, + {dispatch[0]}, + {dispatch[1]}, + {dispatch[2]}}); return program; } diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h index d7cc08aae26f3..e161bffb0c503 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h @@ -38,7 +38,10 @@ class Conv2dMMProgram final : public Program { {"dim_inner", ProgramUniformVariableDataType::Uint32}, {"pads", ProgramUniformVariableDataType::Uint32}, {"strides", ProgramUniformVariableDataType::Uint32}, - {"dilations", ProgramUniformVariableDataType::Uint32}); + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); private: const Activation& activation_; diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc new file mode 100644 index 0000000000000..685324884abeb --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc @@ -0,0 +1,233 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include +#include +#include + +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/nn/im2col_matmul.h" +#include "core/providers/webgpu/nn/activation_util.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { + +// TODO: move to common header. +template +inline T ceil_div(T numerator, T denominator) { + return (numerator + denominator - 1) / denominator; +} + +// Chooses the optimal tile size (M, N) for the im2col operation. +// This tile size is performance-tuned and varies depending on the target device. +std::pair ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) { + // Define a list of preferred (tile_m, tile_n) pairs in descending order of preference. + const std::vector> kTileSizes = { + std::make_pair(32, 64), + std::make_pair(16, 64), + }; + + for (const auto& tile_pair : kTileSizes) { + const uint32_t tile_m = tile_pair.first; + const uint32_t tile_n = tile_pair.second; + + const uint32_t dispatch_m = ceil_div(im2col_m, tile_m); + const uint32_t dispatch_n = ceil_div(im2col_n, tile_n); + const uint32_t dispatch = dispatch_m * dispatch_n; + + if (dispatch >= 128) { + return tile_pair; + } + } + + // If none of the tile sizes meet the dispatch >=128 requirement, + return kTileSizes.back(); +} + +// Add support for more devices. +bool IsDeviceSupported(ComputeContext& context) { + const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); + + if (adapter_info.vendor == std::string_view("intel")) { + if (adapter_info.architecture == std::string_view("xe-2lpg")) { + return true; + } + } + + return false; +} + +} // namespace + +Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template", + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(src, src)); +} + +Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32."); + ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64."); + + return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), + WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_), + WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_), + WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(src, src), + WGSL_TEMPLATE_VARIABLE(weight, weight)); +} + +Status ApplyIm2ColMatMulProgram(ComputeContext& context, + bool is_channels_last, + const std::vector& dilations, + const std::vector& pads, + const std::vector& strides, + Tensor* output) { + const auto* src = context.Input(0); + const auto* weight = context.Input(1); + const bool has_bias = context.InputCount() > 2; + const auto* bias = has_bias ? 
context.Input<Tensor>(2) : nullptr;
+
+  // Transpose OIHW Weight to OHWI
+  // TODO: Move to `Transpose`
+  // TODO: Use prepack
+  TensorShape weight_shape = weight->Shape();
+  const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
+  const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
+  const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
+  const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);
+
+  TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
+  Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
+  OIHW2OHWIProgram transpose_program{};
+  transpose_program.SetWorkgroupSize(64);
+
+  const uint32_t Ci_tiles = ceil_div(channel_input, 64u);
+  transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);
+
+  transpose_program.AddInput({weight,
+                              ProgramTensorMetadataDependency::TypeAndRank});
+  transpose_program.AddOutput({&ohwi_weight,
+                               ProgramTensorMetadataDependency::TypeAndRank});
+  transpose_program.AddUniformVariables({{channel_output},
+                                         {channel_input},
+                                         {kernel_height},
+                                         {kernel_width},
+                                         {Ci_tiles},
+                                         {ceil_div(kernel_height * kernel_width, 4u)}});
+  ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));
+
+  // im2col-matmul
+  const TensorShape src_shape = src->Shape();
+  const TensorShape output_shape = output->Shape();
+
+  const uint32_t batch = onnxruntime::narrow<uint32_t>(src_shape[0]);
+  const uint32_t src_height = onnxruntime::narrow<uint32_t>(src_shape[is_channels_last ? 1 : 2]);
+  const uint32_t src_width = onnxruntime::narrow<uint32_t>(src_shape[is_channels_last ? 2 : 3]);
+  const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]);
+  const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]);
+
+  const uint32_t im2col_m = output_height * output_width;
+  const uint32_t im2col_k = kernel_height * kernel_width * channel_input;
+  const uint32_t im2col_n = channel_output;
+
+  const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n);
+  const uint32_t workgroup_size = tile_n;
+
+  // Check the device's subgroup size before shader compilation to avoid potential performance penalties
+  // associated with conditional checks at shader runtime.
+  //
+  // The subgroup size must be greater than or equal to `tile_m` to safely enable `use_subgroup`.
+  // If this condition cannot be verified, the feature must stay disabled.
+ const bool use_subgroup = false; + Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup}; + im2col_mm_program.SetWorkgroupSize(workgroup_size); + + const uint32_t M_tiles = ceil_div(im2col_m, tile_m); + const uint32_t N_tiles = ceil_div(im2col_n, tile_n); + im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch); + + im2col_mm_program.AddInput({src, + ProgramTensorMetadataDependency::TypeAndRank, + 4}); + im2col_mm_program.AddInput({&ohwi_weight, + ProgramTensorMetadataDependency::TypeAndRank, + 4}); + if (has_bias) { + im2col_mm_program.AddInput({bias, + ProgramTensorMetadataDependency::TypeAndRank}); + } + im2col_mm_program.AddOutput({output, + ProgramTensorMetadataDependency::TypeAndRank}); + im2col_mm_program.AddUniformVariables({{batch}, + {src_height}, + {src_width}, + {channel_input}, + {kernel_height}, + {kernel_width}, + {output_height}, + {output_width}, + {im2col_m}, + {im2col_k}, + {im2col_n}, + {M_tiles}, + {N_tiles}, + {ceil_div(ceil_div(im2col_k, 4u), 4u)}, + {dilations}, + {pads}, + {strides}}); + im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup); + + return context.RunProgram(im2col_mm_program); +} + +bool CanApplyIm2ColMatMulProgram(ComputeContext& context, + const bool is_channels_last, + const ActivationKind activation_kind, + const TensorShape weight_shape, + const AutoPadType auto_pad, + const uint32_t group) { + if (!IsDeviceSupported(context)) { + return false; + } + + // TODO: Support !is_channels_last + // TODO: Support fuse + // TODO: Support auto pad + // TODO: Support group conv + if (!is_channels_last || activation_kind != ActivationKind::None || auto_pad != AutoPadType::NOTSET || group != 1) { + return false; + } + + // TODO: Support conv1d + // TODO: Support conv2d_1x1 + const uint32_t kernel_height = onnxruntime::narrow(weight_shape[2]); + const uint32_t kernel_width = onnxruntime::narrow(weight_shape[3]); + if (kernel_height == 1 || kernel_width == 1) { + return false; + } + + // TODO: Support channel input vec1 + const uint32_t channel_input = onnxruntime::narrow(weight_shape[1]); + if (channel_input % 4 != 0) { + return false; + } + + return true; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h new file mode 100644 index 0000000000000..11b98db8554e4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
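Editor note: to make the tiling logic above concrete, here is a small standalone example (not part of the PR) that reproduces the im2col dimension and dispatch arithmetic from ApplyIm2ColMatMulProgram for a hypothetical 3x3 convolution; the layer sizes are invented for illustration.

#include <cstdint>
#include <cstdio>

// Worked example of the im2col/dispatch math used above, with hypothetical
// layer dimensions.
int main() {
  auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };

  // Hypothetical 3x3 conv: N=1, Ci=64, Co=128, output 56x56.
  const uint32_t output_h = 56, output_w = 56;
  const uint32_t kernel_h = 3, kernel_w = 3;
  const uint32_t channel_i = 64, channel_o = 128, batch = 1;

  const uint32_t im2col_m = output_h * output_w;             // 3136 rows, one per output pixel
  const uint32_t im2col_k = kernel_h * kernel_w * channel_i; // 576 reduction elements
  const uint32_t im2col_n = channel_o;                       // 128 output channels

  // With the preferred (tile_m, tile_n) = (32, 64):
  const uint32_t tile_m = 32, tile_n = 64;
  const uint32_t m_tiles = ceil_div(im2col_m, tile_m);  // 98
  const uint32_t n_tiles = ceil_div(im2col_n, tile_n);  // 2
  // 98 * 2 = 196 >= 128, so ChooseTileSize keeps (32, 64) and the program
  // dispatches (m_tiles, n_tiles, batch) workgroups of tile_n = 64 threads.
  std::printf("M=%u K=%u N=%u dispatch=(%u,%u,%u)\n",
              im2col_m, im2col_k, im2col_n, m_tiles, n_tiles, batch);
  return 0;
}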
+ +#pragma once + +#include + +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { + +// Transpose OIHW Weight to OHWI +class OIHW2OHWIProgram final : public Program { + public: + OIHW2OHWIProgram() : Program("OIHW2OHWI") {} + + Status GenerateShaderCode(ShaderHelper& shader) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"O", ProgramUniformVariableDataType::Uint32}, + {"I", ProgramUniformVariableDataType::Uint32}, + {"H", ProgramUniformVariableDataType::Uint32}, + {"W", ProgramUniformVariableDataType::Uint32}, + {"Ci_tiles", ProgramUniformVariableDataType::Uint32}, + {"H_W_tiles", ProgramUniformVariableDataType::Uint32}); +}; + +class Im2ColMatMulProgram final : public Program { + public: + Im2ColMatMulProgram(bool has_bias, + uint32_t tile_m, + uint32_t tile_n, + bool use_subgroup) : Program("Im2ColMatMul"), + has_bias_(has_bias), + tile_m_(tile_m), + tile_n_(tile_n), + use_subgroup_(use_subgroup) {} + + Status GenerateShaderCode(ShaderHelper& shader) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch", ProgramUniformVariableDataType::Uint32}, + {"src_h", ProgramUniformVariableDataType::Uint32}, + {"src_w", ProgramUniformVariableDataType::Uint32}, + {"channel_i", ProgramUniformVariableDataType::Uint32}, + {"kernel_h", ProgramUniformVariableDataType::Uint32}, + {"kernel_w", ProgramUniformVariableDataType::Uint32}, + {"output_h", ProgramUniformVariableDataType::Uint32}, + {"output_w", ProgramUniformVariableDataType::Uint32}, + {"im2col_m", ProgramUniformVariableDataType::Uint32}, + {"im2col_k", ProgramUniformVariableDataType::Uint32}, + {"im2col_n", ProgramUniformVariableDataType::Uint32}, + {"M_tiles", ProgramUniformVariableDataType::Uint32}, + {"N_tiles", ProgramUniformVariableDataType::Uint32}, + {"K_tiles", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_; + + uint32_t tile_m_; + uint32_t tile_n_; + bool use_subgroup_; +}; + +bool CanApplyIm2ColMatMulProgram(ComputeContext& context, + const bool is_channels_last, + const ActivationKind activation_kind, + const TensorShape kernel_shape, + const AutoPadType auto_pad, + const uint32_t group); + +Status ApplyIm2ColMatMulProgram(ComputeContext& context, + const bool is_channels_last, + const std::vector& dilations, + const std::vector& pads, + const std::vector& strides, + Tensor* output); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template new file mode 100644 index 0000000000000..2f64525469561 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
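Editor note: OIHW2OHWIProgram declared above relayouts the convolution weight on the GPU. As a reference for the index mapping (the same offset arithmetic used by the oihw_to_ohwi.wgsl.template later in this diff), here is a hypothetical scalar CPU helper; the float element type and the function name are illustrative assumptions.

#include <cstdint>
#include <vector>

// CPU reference for the OIHW -> OHWI weight relayout performed by
// OIHW2OHWIProgram (same offset arithmetic as the WGSL template, scalar and
// unoptimized). Intended only to illustrate and verify the mapping.
std::vector<float> TransposeOIHW2OHWI(const std::vector<float>& src,
                                      uint32_t O, uint32_t I, uint32_t H, uint32_t W) {
  std::vector<float> dst(static_cast<size_t>(O) * H * W * I);
  for (uint32_t o = 0; o < O; ++o) {
    for (uint32_t i = 0; i < I; ++i) {
      for (uint32_t hw = 0; hw < H * W; ++hw) {
        const size_t src_offset = (static_cast<size_t>(o) * I + i) * H * W + hw;  // OIHW
        const size_t dst_offset = (static_cast<size_t>(o) * H * W + hw) * I + i;  // OHWI
        dst[dst_offset] = src[src_offset];
      }
    }
  }
  return dst;
}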
+ +#param has_bias +#param tile_m +#param tile_n +#param use_subgroup + +#use .getByOffset .setByOffset + +// im2col access for src: [N, H_i, W_i, C_i / 4] (vec4-packed NHWC) +// Conceptual Matrix Shape: N * (H_o * W_o) x (K_h * K_w * C_i / 4) +fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t { + if (batch >= uniforms.batch || m >= uniforms.im2col_m || k_packed_idx * 4 >= uniforms.im2col_k) { + return src_value_t(); + } + + let channel_i_v4 = uniforms.channel_i / 4; + + // 1. Decompose M index (H_o * W_o) into (h_idx, w_idx) + let h_idx = m / uniforms.output_w; // Output H index (H_o) + let w_idx = m % uniforms.output_w; // Output W index (W_o) + + // 2. Decompose K index into (k_h, k_w, c_i_v4_idx) + let c_i_v4_idx = k_packed_idx % channel_i_v4; + let k_h_w_idx = k_packed_idx / channel_i_v4; + let k_h = k_h_w_idx / uniforms.kernel_w; // Kernel Row + let k_w = k_h_w_idx % uniforms.kernel_w; // Kernel Column + + // 3. Calculate the coordinate in the padded input tensor + let src_h_coord_padded = h_idx * uniforms.strides.x + k_h * uniforms.dilations.x; + let src_w_coord_padded = w_idx * uniforms.strides.y + k_w * uniforms.dilations.y; + + // 4. Calculate the coordinate in the original input tensor + let src_h_coord : i32 = i32(src_h_coord_padded) - i32(uniforms.pads.x); + let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.z); + + // 5. Check for padding/out-of-bounds + if (src_h_coord < 0 || src_h_coord >= i32(uniforms.src_h) || + src_w_coord < 0 || src_w_coord >= i32(uniforms.src_w)) { + return src_value_t(); + } + + // 6. Calculate final NHWC/vec4 index + let src_idx = batch * uniforms.src_h * uniforms.src_w * channel_i_v4 + + u32(src_h_coord) * uniforms.src_w * channel_i_v4 + + u32(src_w_coord) * channel_i_v4 + + c_i_v4_idx; + return src.getByOffset(src_idx); +} + +// weight shape: [Co, K_h, K_w, C_i / 4] (vec4-packed CoHWCi) +fn load_weight(n : u32, k_packed_idx : u32) -> weight_value_t { + if (n < uniforms.im2col_n && k_packed_idx < uniforms.im2col_k / 4) { + let weight_idx = n * uniforms.im2col_k / 4 + + k_packed_idx; + return weight.getByOffset(weight_idx); + } + return weight_value_t(); +} + +fn load_bias(n : u32) -> output_element_t { +#if has_bias + if (n < uniforms.im2col_n) { + return output_element_t(bias[n]); + } +#endif + return output_element_t(); +} + +// output shape: [N, H_o, W_o, C_o] (NHWC) +fn write_output(batch : u32, m : u32, n : u32, value : output_element_t) { + if (batch < uniforms.batch && m < uniforms.im2col_m && n < uniforms.im2col_n) { + let output_idx = batch * uniforms.im2col_m * uniforms.im2col_n + + m * uniforms.im2col_n + + n; + output.setByOffset(output_idx, value); + } +} + +const TILE_M_SIZE : u32 = tile_m; +const TILE_N_SIZE : u32 = tile_n; +const TILE_K_VEC_SIZE : u32 = 4; + +var src_tile : array, TILE_K_VEC_SIZE>; +var weight_tile : array, TILE_K_VEC_SIZE>; + +$MAIN { + let batch = workgroup_idx / (uniforms.M_tiles * uniforms.N_tiles); + let m_global_base = ((workgroup_idx / uniforms.N_tiles) % uniforms.M_tiles) * TILE_M_SIZE; + let n_global_base = (workgroup_idx % uniforms.N_tiles) * TILE_N_SIZE; + + var results : array; + for (var k_idx = 0u; k_idx < uniforms.K_tiles; k_idx++) { + for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 16u) { + // Loads a 16x4 vec of src into the workgroup memory. 
+ let load_src_m = src_m + local_idx / 4; + let load_src_k = local_idx % 4; + + src_tile[load_src_k][load_src_m] = load_src(batch, + m_global_base + load_src_m, + k_idx * TILE_K_VEC_SIZE + load_src_k); + } + + for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 16u) { + // Loads a 16x4 vec of weight into the workgroup memory. + let load_weight_n = weight_n + local_idx / 4; + let load_weight_k = local_idx % 4; + + weight_tile[load_weight_k][load_weight_n] = load_weight(n_global_base + load_weight_n, + k_idx * TILE_K_VEC_SIZE + load_weight_k); + } + workgroupBarrier(); + + for (var inner_k_idx = 0u; inner_k_idx < TILE_K_VEC_SIZE; inner_k_idx++) { + let weight_data = weight_tile[inner_k_idx][local_idx]; +#if use_subgroup + let src_data = src_tile[inner_k_idx][sg_id]; + for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) { + results[m_idx] += output_element_t(dot(weight_data, subgroupShuffle(src_data, m_idx))); + } +#else + for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) { + results[m_idx] += output_element_t(dot(weight_data, src_tile[inner_k_idx][m_idx])); + } +#endif + } + workgroupBarrier(); + } + + let m_base = m_global_base; + let n_base = n_global_base + local_idx; + + let bias = load_bias(n_base); + for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) { + var output_data = results[m_idx] + bias; + write_output(batch, m_base + m_idx, n_base, output_data); + } +} // MAIN diff --git a/onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template b/onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template new file mode 100644 index 0000000000000..dfbe7dde28b53 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
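Editor note: for readers verifying the load_src() addressing in the im2col_matmul template above, the following hypothetical scalar CPU helper (not part of the PR) mirrors the same decomposition of the (m, k) matrix coordinates into input coordinates. The batch dimension is folded out and channels are handled per scalar rather than vec4-packed, purely for clarity.

#include <cstdint>
#include <vector>

// Scalar CPU reference of the im2col addressing in load_src(): one element of
// the conceptual (H_o*W_o) x (K_h*K_w*C_i) matrix, read from an NHWC image,
// returning 0 for padded positions.
float Im2ColAt(const std::vector<float>& src_nhwc,
               uint32_t src_h, uint32_t src_w, uint32_t channel_i,
               uint32_t output_w, uint32_t kernel_w,
               uint32_t stride_h, uint32_t stride_w,
               uint32_t dilation_h, uint32_t dilation_w,
               int32_t pad_top, int32_t pad_left,
               uint32_t m, uint32_t k) {
  // Decompose the M index into an output coordinate.
  const uint32_t h_idx = m / output_w;
  const uint32_t w_idx = m % output_w;
  // Decompose the K index into (k_h, k_w, ci); ci is the innermost dimension.
  const uint32_t ci = k % channel_i;
  const uint32_t k_hw = k / channel_i;
  const uint32_t k_h = k_hw / kernel_w;
  const uint32_t k_w = k_hw % kernel_w;
  // Map to the (possibly padded-out) input coordinate.
  const int32_t h = static_cast<int32_t>(h_idx * stride_h + k_h * dilation_h) - pad_top;
  const int32_t w = static_cast<int32_t>(w_idx * stride_w + k_w * dilation_w) - pad_left;
  if (h < 0 || h >= static_cast<int32_t>(src_h) || w < 0 || w >= static_cast<int32_t>(src_w)) {
    return 0.0f;  // padding region
  }
  return src_nhwc[(static_cast<size_t>(h) * src_w + w) * channel_i + ci];
}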
+ +#use .getByOffset .setByOffset + +fn load_src(co : u32, ci : u32, h_w : u32) -> src_element_t { + if (co < uniforms.O && ci < uniforms.I && h_w < uniforms.H * uniforms.W) { + let offset = co * uniforms.I * uniforms.H * uniforms.W + + ci * uniforms.H * uniforms.W + + h_w; + return src.getByOffset(offset); + } + return src_element_t(); +} + +fn write_output(co : u32, h_w : u32, ci : u32, value : output_element_t) { + if (co < uniforms.O && ci < uniforms.I && h_w < uniforms.H * uniforms.W) { + let offset = co * uniforms.H * uniforms.W * uniforms.I + + h_w * uniforms.I + + ci; + output.setByOffset(offset, value); + } +} + +var data_cache : array, 4>; + +$MAIN { + let group_co : u32 = workgroup_idx / uniforms.Ci_tiles; + let group_ci : u32 = (workgroup_idx % uniforms.Ci_tiles) * 64; + + if (group_co >= uniforms.O || group_ci >= uniforms.I) { + return; + } + + for (var h_w_idx = 0u; h_w_idx < uniforms.H_W_tiles; h_w_idx++) { + // load + for (var ci_idx = 0u; ci_idx < 64u; ci_idx += 16u) { + let load_ci_idx = ci_idx + local_idx / 4; + let load_h_w_idx = local_idx % 4; + + data_cache[load_h_w_idx][load_ci_idx] = load_src(group_co, + group_ci + load_ci_idx, + h_w_idx * 4 + load_h_w_idx); + } + workgroupBarrier(); + + // store + for (var local_h_w_idx = 0u; local_h_w_idx < 4u; local_h_w_idx++) { + let output_data = data_cache[local_h_w_idx][local_idx]; + write_output(group_co, h_w_idx * 4 + local_h_w_idx, group_ci + local_idx, output_data); + } + workgroupBarrier(); + } +} // MAIN diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 8ef229b9ef69f..f2996da4bd29e 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -119,6 +119,7 @@ Status ProgramManager::Build(const ProgramBase& program, wgpu::ShaderModuleDescriptor descriptor{}; descriptor.nextInChain = &wgsl_source; + descriptor.label = program.Name().c_str(); auto shader_module = device.CreateShaderModule(&descriptor); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index b08649cbd5d5b..e235506a44c89 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -140,8 +140,12 @@ namespace { // Validate if the tensor element type matches the program variable data type Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type, bool is_atomic = false) { if (is_atomic) { - // float32 is not a valid data type for atomic. However the data may be bitcast-ed to i32 and used to simulate atomic operation using atomicCompareExchangeWeak. - ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || var_type == ProgramVariableDataType::Uint32 || var_type == ProgramVariableDataType::Float32, + // float32, float32x4 and float16x4 are not valid data types for atomic. However the data may be bitcast-ed to i32 and used to simulate atomic operation using atomicCompareExchangeWeak. 
+ ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || + var_type == ProgramVariableDataType::Uint32 || + var_type == ProgramVariableDataType::Float32 || + var_type == ProgramVariableDataType::Float16x4 || + var_type == ProgramVariableDataType::Float32x4, "Unexpected program variable type ", int(var_type), " for atomic variable"); } diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 7e8b434431781..5f59fecc425e2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -92,14 +92,28 @@ Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +static std::vector getInt64Input(const Tensor* tensor) { + if (tensor->IsDataType()) { + return std::vector(tensor->DataAsSpan().begin(), tensor->DataAsSpan().end()); + } + ORT_ENFORCE(tensor->IsDataType(), "Expected tensor of type int32 or int64"); + std::vector result; + auto span = tensor->DataAsSpan(); + result.reserve(span.size()); + for (auto v : span) { + result.push_back(static_cast(v)); + } + return result; +} + Status Slice::ComputeInternal(ComputeContext& context) const { // READ INPUTS const Tensor* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); auto input_rank = input_shape.NumDimensions(); - auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan() : gsl::make_span(attr_starts_); - auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan() : gsl::make_span(attr_ends_); + auto starts_raw = attr_starts_.empty() ? getInt64Input(context.Input(1)) : attr_starts_; + auto ends_raw = attr_ends_.empty() ? getInt64Input(context.Input(2)) : attr_ends_; ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size"); @@ -126,7 +140,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { axes_default.push_back(i); } } - auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan()) : gsl::make_span(attr_axes_); + auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? axes_default : getInt64Input(axes_tensor)) : attr_axes_; std::vector steps_default; if (steps_tensor == nullptr) { @@ -135,7 +149,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { steps_default.push_back(1); } } - auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan(); + auto steps_raw = steps_tensor == nullptr ? 
steps_default : getInt64Input(steps_tensor); // get final axes std::vector axes, axes_fixed; diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index cec321d0da80e..5415d4a5ead5b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output) { const auto& input_shape = input.Shape(); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index b62a419fa12bc..5e9ccc6750cd6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; - static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); + static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output); constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 28decb076951e..b8d5adc421124 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -147,6 +147,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create program manager program_mgr_ = std::make_unique(*this); + // create split-k config + split_k_config_ = std::make_unique(adapter_info_); + // set query type #if !defined(__wasm__) if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { @@ -178,7 +181,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { +Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -288,8 +291,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch); if (is_profiling_) { - PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), - context.KernelContext().GetOpType(), + PendingKernelInfo pending_kernel_info(context.NodeName(), + context.OpType(), program.Name(), key, inputs, @@ -442,7 +445,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; - const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context); + const webgpu::BufferManager& buffer_mgr = 
ComputeContextBase::BufferManagerAccessor::Get(context); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -910,13 +913,6 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index bd7dae75f2e2d..84dfb47ef4687 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,7 +5,6 @@ #include #include -#include #include "core/providers/webgpu/webgpu_external_header.h" @@ -23,7 +22,7 @@ class Tensor; namespace webgpu { class WebGpuContext; -class ComputeContext; +class ComputeContextBase; class ProgramBase; // Definition for CapturedCommandInfo in the webgpu namespace @@ -152,6 +151,13 @@ class WebGpuContext final { return validation_mode_; } + // + // Get Split-K configuration. + // + const SplitKConfig& GetSplitKConfig() const { + return *split_k_config_; + } + void StartProfiling(); void CollectProfilingData(profiling::Events& events); void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); @@ -170,16 +176,9 @@ class WebGpuContext final { // Status PopErrorScope(); - Status Run(ComputeContext& context, const ProgramBase& program); + Status Run(ComputeContextBase& context, const ProgramBase& program); void OnRunEnd(); - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. - // - const SplitKConfig& GetSplitKConfig(); - private: enum class TimestampQueryType { None = 0, @@ -277,7 +276,7 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; - std::optional split_k_config_; + std::unique_ptr split_k_config_; // profiling TimestampQueryType query_type_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index e0b84fef51f1f..6b764d51bcf75 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -794,8 +794,7 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& config) - : IExecutionProvider{kWebGpuExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, + : IExecutionProvider{kWebGpuExecutionProvider, WebGpuDevice}, context_id_{context_id}, context_{context}, preferred_data_layout_{config.data_layout}, @@ -935,13 +934,14 @@ std::unique_ptr WebGpuExecutionProvider::GetEx std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, std::string_view node_op_type, DataLayout target_data_layout) const { - if (target_data_layout != DataLayout::NHWC) { - return std::nullopt; - } - // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider if (node_domain == kOnnxDomain && node_op_type == "Resize") { - return false; + return target_data_layout != DataLayout::NHWC; + } + + // WebGPU perfer NCHW for InstanceNormalization due to a better performance + if (node_domain == kOnnxDomain && node_op_type == "InstanceNormalization") { + return target_data_layout != DataLayout::NHWC; } return std::nullopt; diff --git 
a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index 8d6ae6caeaf83..ea38e9415e1fe 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -11,25 +11,58 @@ namespace webgpu { WebGpuKernel::WebGpuKernel(const OpKernelInfo& info) : OpKernel(info), - ep_(*static_cast(info.GetExecutionProvider())) { + ep_(*static_cast(info.GetExecutionProvider())), + webgpu_context_(WebGpuContextFactory::GetContext(ep_.GetDeviceId())) { } Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const { - WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId()); - ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context}; + ComputeContext context{webgpu_context_, + ep_, + *this, + *p_op_kernel_context}; - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - webgpu_context.PushErrorScope(); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + webgpu_context_.PushErrorScope(); } Status s = ComputeInternal(context); - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context.PopErrorScope()); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); } return s; } +Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { + ComputeContextBase context{webgpu_context_, ep_, *this}; + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + webgpu_context_.PushErrorScope(); + } + + // Currently, ORT does not allow using prepacked weights in non-CPU EPs. + // So we do not pass prepacked_weights to PrePackInternal. + // Kernel implementation that supports prepacking should manage its own storage. + + Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed); + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); + } + + return s; +} + +Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/, + const Tensor& /*tensor*/, + int /*input_idx*/, + AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed) { + is_packed = false; + return Status::OK(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 3c750e305421c..2c57991c6ee35 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -23,8 +23,41 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; + // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. + // This method creates a ComputeContextBase and delegates to PrePackInternal. + // + // NOTE: Currently, ORT does not allow using prepacked weights in non-CPU EPs, so the + // prepacked_weights parameter is not passed to PrePackInternal. Kernel implementations + // that support prepacking should manage their own storage. + Status PrePack(const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + // Virtual method that allows derived kernels to pre-process constant tensors during initialization. 
+ // + // This method is called during kernel initialization when constant tensors are available, + // allowing kernels to perform operations like tensor transposition or format conversion + // before the first Compute call. + // + // @param context The WebGPU compute context base providing access to the execution environment. + // @param tensor The constant tensor to potentially pre-process. + // @param input_idx The index of this input in the kernel's input list. + // @param alloc The allocator to use for any new tensor allocations. + // @param is_packed Output parameter. Set to true if the tensor was pre-packed/processed, + // false otherwise. The default implementation sets this to false. + // + // @return Status::OK() on success, or an error status on failure. + virtual Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed); + private: const WebGpuExecutionProvider& ep_; + WebGpuContext& webgpu_context_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 568d29a96cb88..824cfb02c22f0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "core/providers/webgpu/webgpu_utils.h" + +#include +#include "core/providers/webgpu/shader_variable.h" + namespace onnxruntime { namespace webgpu { @@ -21,27 +25,24 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } -SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { - SplitKConfig config = {}; - +SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) { if (adapter_info.vendor == std::string_view{"intel"}) { if (adapter_info.architecture == std::string_view{"xe-2lpg"} || adapter_info.architecture == std::string_view{"xe-2hpg"} || adapter_info.architecture == std::string_view{"xe-lpg"} || adapter_info.architecture == std::string_view{"gen-12hp"}) { - config.enable_split_k_ = true; + enable_split_k_ = true; // Below thresholds are only verified on the above Intel GPUs without any regressions. The // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more // atomic calls for each output value. 
- config.split_dim_inner_ = 256; - config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; - config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + split_dim_inner_ = 256; + min_dim_inner_with_split_k_ = split_dim_inner_ * 2; + max_dim_inner_with_split_k_ = split_dim_inner_ * 9; + max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; } } - return config; } bool SplitKConfig::UseSplitK( @@ -80,5 +81,23 @@ uint32_t SplitKConfig::GetSplitDimInner() const { return split_dim_inner_; } +std::string GenerateAtomicAddNonIntegerCode(const ShaderVariableHelper& output, const std::string& offset, const std::string& output_type, const std::string& add_value) { + std::ostringstream ss; + + std::string get_output_by_offset = output.GetByOffset(offset); + ss << "while (true) {\n" + << " let old_output_i32 = atomicLoad(&" << get_output_by_offset << ");\n" + << " let old_output_" << output_type << " = bitcast<" << output_type << ">(old_output_i32);\n" + << " let new_output_" << output_type << " = old_output_" << output_type << " + " << add_value << ";\n" + << " let new_output_i32 = bitcast(new_output_" << output_type << ");\n" + << " let output_compare_exchange = atomicCompareExchangeWeak(&" << get_output_by_offset << ", old_output_i32, new_output_i32);\n" + << " if (output_compare_exchange.old_value == old_output_i32) {\n" + << " break;\n" + << " }\n" + << "}\n"; + + return ss.str(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index d45b9bf4dd119..0aa47371f6752 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -13,6 +13,8 @@ namespace onnxruntime { namespace webgpu { +class ShaderVariableHelper; + /** * Returns the maximum number of components `N` to be used as `vecN` for the given size. */ @@ -91,9 +93,12 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } +/** + * Configuration for Split-K optimization (Conv|MatMul). + */ class SplitKConfig { public: - static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); + explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, @@ -110,5 +115,21 @@ class SplitKConfig { float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f; }; +/** + * Generates WGSL (WebGPU Shading Language) code for performing an atomic add operation + * on a non-integer value (e.g., floating-point) in a shader. + * + * Since WGSL natively supports atomic operations only on integer types, this function + * generates code that emulates atomic addition for non-integer types using a compare-and-swap loop. + * + * @param output A reference to the ShaderVariableHelper representing the atomic variable + * to be updated. This encapsulates the variable's name and access logic. + * @param offset The offset or index within the atomic variable where the operation is applied. + * @param output_type The WGSL type of the value being added (e.g., "f32"). + * @param add_value The expression or variable representing the value to add. + * @return A string containing the generated WGSL code for the atomic add operation. 
+ */ +std::string GenerateAtomicAddNonIntegerCode(const ShaderVariableHelper& output, const std::string& offset, const std::string& output_type, const std::string& add_value); + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/session/abi_ep_types.cc b/onnxruntime/core/session/abi_ep_types.cc index 14764251898aa..5f45ea0a2b808 100644 --- a/onnxruntime/core/session/abi_ep_types.cc +++ b/onnxruntime/core/session/abi_ep_types.cc @@ -10,6 +10,10 @@ #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" +OrtEpGraphSupportInfo::OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) + : ort_graph(graph), kernel_lookup{kernel_lookup} {} + onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span nodes, const OrtNodeFusionOptions* optional_fusion_options) { std::vector ep_nodes; diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h index eb68d79a24279..deaadf7c67e6e 100644 --- a/onnxruntime/core/session/abi_ep_types.h +++ b/onnxruntime/core/session/abi_ep_types.h @@ -10,6 +10,7 @@ #include "core/common/inlined_containers_fwd.h" #include "core/common/status.h" +#include "core/framework/execution_provider.h" #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { @@ -39,7 +40,8 @@ struct OrtEpGraphSupportInfo { OrtNodeFusionOptions fusion_options = {}; }; - explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {} + OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph, + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup); onnxruntime::Status AddNodesToFuse(gsl::span nodes, const OrtNodeFusionOptions* node_fusion_options = nullptr); @@ -47,4 +49,5 @@ struct OrtEpGraphSupportInfo { const onnxruntime::EpGraph& ort_graph; std::vector node_groupings; + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4d4dea9cb444c..e5523dc78b5d2 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -162,7 +162,6 @@ static bool AreAllComputeNodesAssignedToCudaOrJsOrDmlEpWebGpuEp(const Graph& gra // Empty node provider means CPU EP if (!node_provider.empty() && !(node_provider == kCudaExecutionProvider || - node_provider == kRocmExecutionProvider || node_provider == kJsExecutionProvider || node_provider == kWebGpuExecutionProvider || node_provider == kDmlExecutionProvider) && @@ -1320,6 +1319,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool *session_logger_)); } + // We choose to convert initializers into OrtValues before partitioning here so plug-in EPs could + // take advantage of the initializers being in OrtValue format and not to deal with protobuf. + // + // The initializers data is transferred to an OrtValue. The original TensorProto is replaced + // with a TensorProto that has the same data type, shape and name. However, its external data + // is used in a non-standard way. The location is set to a string constant utils::kTensorProtoMemoryAddressTag, + // The file offset is set to the address of the OrtValue's data buffer, and the length is set to the size of the + // OrtValue's data buffer. Because this external location is non-standard, onnx code can not handle it. 
For this reason, + // we do not convert them at the graph constructor because Node::ToProto() reconstructs Graph instances for subgraphs + // and we do not want to have initializers converted at shape inference time, as Resolve() is called from EPs when + // op_types are not assigned yet. + // + // If any transformations are applied later, they would not introduce any in-memory initializers, + // type and shape inference would run only on any newly added nodes and any new initializers + // will be converted at session finalization time. + // + // The conversion is performed using the following steps (within ConvertInitializersIntoOrtValues()) + // constexpr const bool use_tensor_buffer_true = true; + // auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(), + // use_tensor_buffer_true); + // ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value)); + ORT_RETURN_IF_ERROR_SESSIONID_(graph.ConvertInitializersIntoOrtValues()); + auto apply_transformer_once = [](const GraphTransformer& transformer, const logging::Logger& logger, Graph& graph, bool* is_graph_modified = nullptr) -> onnxruntime::common::Status { bool modified = false; @@ -2269,7 +2291,7 @@ common::Status InferenceSession::Initialize() { "Session initialization canceled due to user request."); } - // Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP. + // Currently graph capture is only considered by CUDA EP, TRT EP and JS EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -2289,16 +2311,9 @@ common::Status InferenceSession::Initialize() { // All the "compute" graph nodes have been assigned to the JS EP, // Then the JS EP is cached for triggering a ReplayGraph() in Run(). // - // Check for ROCM EP: - // If the ROCM EP is part of the providers list for this session AND - // The ROCM EP is configured to do a graph capture AND - // All the "compute" graph nodes have been assigned to the ROCM EP, - // Then the ROCM EP is cached for triggering a ReplayGraph() in Run(). 
- // std::vector graph_support_ep_list = { onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, onnxruntime::kJsExecutionProvider, onnxruntime::kWebGpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; @@ -2321,7 +2336,6 @@ common::Status InferenceSession::Initialize() { } if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || - strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kWebGpuExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kDmlExecutionProvider) == 0) { @@ -2515,6 +2529,12 @@ common::Status InferenceSession::Initialize() { LOGS(*session_logger_, ERROR) << status.ErrorMessage(); }); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), MakeString("Exception during initialization: ", ex.what())); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Exception during initialization: ", ex.what()); @@ -2943,6 +2963,8 @@ Status InferenceSession::Run(const RunOptions& run_options, << cached_execution_provider_for_graph_replay_.Type() << " CUDA Graph for this model with tag: " << run_options.run_tag << " with graph annotation id: " << graph_annotation_id; + // log evaluation start to trace logging provider + env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector exec_providers_to_stop; @@ -3134,7 +3156,6 @@ Status InferenceSession::Run(const RunOptions& run_options, // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, // so that users just need one session run to capture the graph. // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, - // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, // and the value could be different for other EP. 
if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && cached_execution_provider_for_graph_replay_.AllowGraphCaptureOnRun(graph_annotation_id) && diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 394f69bb15b19..82f7cef4aec49 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3403,7 +3403,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice // create the wrapper class that uses the EP implementation auto stream = std::make_unique(ep_device->device_memory_info->device, - *stream_impl, LoggingManager::DefaultLogger()); + *stream_impl, *LoggingManager::DefaultLogger().ToExternal()); // cast to base type, and to API alias type *ort_stream = static_cast(static_cast(stream.release())); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c0e4d32ac0167..f3525d8de7b95 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -256,6 +256,7 @@ ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ c ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); + ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO, diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index 7efb0a68c735d..e89944394aaec 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -6,11 +6,13 @@ #include #include #include +#include #include #include "core/common/semver.h" #include "core/framework/error_code_helper.h" #include "core/framework/func_api.h" +#include "core/framework/op_kernel_info.h" #include "core/framework/ort_value.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" @@ -21,6 +23,8 @@ #include "core/session/abi_ep_types.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/ort_apis.h" +#include "core/session/plugin_ep/ep_kernel_registration.h" +#include "core/session/utils.h" using namespace onnxruntime; namespace OrtExecutionProviderApi { @@ -73,7 +77,7 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInf _In_opt_ const OrtNodeFusionOptions* node_fusion_options) { API_IMPL_BEGIN if (ort_graph_support_info == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance"); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtEpGraphSupportInfo instance"); } if (num_nodes == 0 || nodes == nullptr) { @@ -233,6 +237,351 @@ ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device) delete device; } +ORT_API_STATUS_IMPL(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry) { + API_IMPL_BEGIN + auto unique_kernel_registry = std::make_unique(); + + *kernel_registry = reinterpret_cast(unique_kernel_registry.release()); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseKernelRegistry, _Frees_ptr_opt_ OrtKernelRegistry* kernel_registry) { + delete reinterpret_cast(kernel_registry); +} + +ORT_API_STATUS_IMPL(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry, + 
_In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func, + _In_ void* kernel_create_func_state) { + API_IMPL_BEGIN + if (kernel_registry == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelRegistry"); + } + + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelDef"); + } + + if (kernel_create_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelCreateFunc"); + } + + auto* internal_kernel_def = reinterpret_cast(kernel_def); + onnxruntime::KernelCreateInfo kernel_create_info = MakePluginEpKernelCreateInfo(internal_kernel_def, + kernel_create_func, + kernel_create_func_state); + + auto* actual_registry = reinterpret_cast(kernel_registry); + ORT_API_RETURN_IF_STATUS_NOT_OK(actual_registry->Register(std::move(kernel_create_info))); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out) { + API_IMPL_BEGIN + auto builder = onnxruntime::KernelDefBuilder::Create(); + *kernel_def_builder_out = reinterpret_cast(builder.release()); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseKernelDefBuilder, _Frees_ptr_opt_ OrtKernelDefBuilder* kernel_def_builder) { + delete reinterpret_cast(kernel_def_builder); +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* op_type) { + API_IMPL_BEGIN + auto* builder = reinterpret_cast(kernel_def_builder); + builder->SetName(op_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* domain) { + API_IMPL_BEGIN + auto* builder = reinterpret_cast(kernel_def_builder); + builder->SetDomain(domain); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ int since_version_start, _In_ int since_version_end) { + API_IMPL_BEGIN + auto* builder = reinterpret_cast(kernel_def_builder); + + // start version must be >= 1 + if (since_version_start < 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Start version must be >= 1"); + } + + // end version must >= start version + if (since_version_end < since_version_start) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "End version must be >= to the start version"); + } + + builder->SinceVersion(since_version_start, since_version_end); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* ep_name) { + API_IMPL_BEGIN + auto* builder = reinterpret_cast(kernel_def_builder); + builder->Provider(ep_name); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t input_index, _In_ OrtMemType mem_type) { + API_IMPL_BEGIN + auto* builder = reinterpret_cast(kernel_def_builder); + builder->InputMemoryType(mem_type, static_cast(input_index)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t output_index, _In_ OrtMemType mem_type) { + API_IMPL_BEGIN + auto* builder = reinterpret_cast(kernel_def_builder); + builder->OutputMemoryType(mem_type, static_cast(output_index)); + return nullptr; + 
API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* arg_name, _In_reads_(num_types) const OrtDataType* const* types, + _In_ size_t num_types) { + API_IMPL_BEGIN + if (num_types == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one OrtDataType instance"); + } + + if (types == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of OrtDataType instances"); + } + + if (arg_name == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a valid name for a kernel definition's type constraint"); + } + + std::vector ml_types; + ml_types.reserve(num_types); + + for (size_t i = 0; i < num_types; i++) { + ml_types.push_back(reinterpret_cast(types[i])); + } + + auto* builder = reinterpret_cast(kernel_def_builder); + builder->TypeConstraint(arg_name, std::move(ml_types)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputAliases, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_reads_(num_io_indices) int const* input_indices, + _In_reads_(num_io_indices) int const* output_indices, + _In_ size_t num_io_indices) { + API_IMPL_BEGIN + if (num_io_indices == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one input/output alias"); + } + + if (input_indices == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of input indices to alias"); + } + + if (output_indices == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of output indices to alias"); + } + + auto* builder = reinterpret_cast(kernel_def_builder); + + if (num_io_indices == 1) { + builder->Alias(input_indices[0], output_indices[0]); + } else { + std::vector> pairs; + pairs.reserve(num_io_indices); + + for (size_t i = 0; i < num_io_indices; ++i) { + pairs.push_back({input_indices[i], output_indices[i]}); + } + + builder->Alias(pairs); + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputMutableAliases, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_reads_(num_io_indices) int const* input_indices, + _In_reads_(num_io_indices) int const* output_indices, + _In_ size_t num_io_indices) { + API_IMPL_BEGIN + if (num_io_indices == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one input/output alias (mutable)"); + } + + if (input_indices == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a valid array of input indices to alias (mutable)"); + } + + if (output_indices == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a valid array of output indices to alias (mutable)"); + } + + auto* builder = reinterpret_cast(kernel_def_builder); + + if (num_io_indices == 1) { + builder->MayInplace(input_indices[0], output_indices[0]); + } else { + std::vector> pairs; + pairs.reserve(num_io_indices); + + for (size_t i = 0; i < num_io_indices; ++i) { + pairs.push_back({input_indices[i], output_indices[i]}); + } + + builder->MayInplace(pairs); + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder, + _Outptr_ OrtKernelDef** kernel_def_out) { + API_IMPL_BEGIN + auto* builder = reinterpret_cast(kernel_def_builder); + *kernel_def_out = reinterpret_cast(builder->Build().release()); + return nullptr; + API_IMPL_END +} + 
+ORT_API(void, ReleaseKernelDef, _Frees_ptr_opt_ OrtKernelDef* kernel_def) { + delete reinterpret_cast(kernel_def); +} + +ORT_API(const char*, KernelDef_GetOperatorType, _In_ const OrtKernelDef* kernel_def) { + return reinterpret_cast(kernel_def)->OpName().c_str(); +} + +ORT_API(const char*, KernelDef_GetDomain, _In_ const OrtKernelDef* kernel_def) { + return reinterpret_cast(kernel_def)->Domain().c_str(); +} + +ORT_API_STATUS_IMPL(KernelDef_GetSinceVersion, _In_ const OrtKernelDef* kernel_def, + _Out_ int* start_version, _Out_ int* end_version) { + API_IMPL_BEGIN + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtKernelDef"); + } + + if (start_version == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `start_version` output parameter"); + } + + if (end_version == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `end_version` output parameter"); + } + + auto* internal_kernel_def = reinterpret_cast(kernel_def); + internal_kernel_def->SinceVersion(start_version, end_version); + + return nullptr; + API_IMPL_END +} + +ORT_API(const char*, KernelDef_GetExecutionProvider, _In_ const OrtKernelDef* kernel_def) { + return reinterpret_cast(kernel_def)->Provider().c_str(); +} + +ORT_API_STATUS_IMPL(KernelDef_GetInputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t input_index, _Out_ OrtMemType* mem_type) { + API_IMPL_BEGIN + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtKernelDef"); + } + + if (mem_type == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `mem_type` output parameter"); + } + + auto* internal_kernel_def = reinterpret_cast(kernel_def); + *mem_type = internal_kernel_def->InputMemoryType(input_index); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(KernelDef_GetOutputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t output_index, _Out_ OrtMemType* mem_type) { + API_IMPL_BEGIN + if (kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtKernelDef"); + } + + if (mem_type == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null `mem_type` output parameter"); + } + + auto* internal_kernel_def = reinterpret_cast(kernel_def); + *mem_type = internal_kernel_def->OutputMemoryType(output_index); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(GetTensorDataType, _In_ ONNXTensorElementDataType elem_type, + _Outptr_ const OrtDataType** out) { + API_IMPL_BEGIN + const DataTypeImpl* ml_type = DataTypeImpl::TensorTypeFromONNXEnum(elem_type); + *out = reinterpret_cast(ml_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def) { + API_IMPL_BEGIN + if (out_kernel_def == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtKernelDef output parameter"); + } + + *out_kernel_def = nullptr; + + if (graph_support_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtEpGraphSupportInfo instance"); + } + + if (node == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid non-null OrtNode instance"); + } + + const 
onnxruntime::EpNode* ep_node = onnxruntime::EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "OrtNode created via the ModelEditor API is not supported"); + } + + const onnxruntime::KernelCreateInfo* create_info = + graph_support_info->kernel_lookup.LookUpKernel(ep_node->GetInternalNode()); + + *out_kernel_def = create_info != nullptr ? reinterpret_cast(create_info->kernel_def.get()) + : nullptr; + return nullptr; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -262,6 +611,31 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::CreateHardwareDevice, &OrtExecutionProviderApi::ReleaseHardwareDevice, + + &OrtExecutionProviderApi::CreateKernelRegistry, + &OrtExecutionProviderApi::ReleaseKernelRegistry, + &OrtExecutionProviderApi::KernelRegistry_AddKernel, + &OrtExecutionProviderApi::CreateKernelDefBuilder, + &OrtExecutionProviderApi::ReleaseKernelDefBuilder, + &OrtExecutionProviderApi::KernelDefBuilder_SetOperatorType, + &OrtExecutionProviderApi::KernelDefBuilder_SetDomain, + &OrtExecutionProviderApi::KernelDefBuilder_SetSinceVersion, + &OrtExecutionProviderApi::KernelDefBuilder_SetExecutionProvider, + &OrtExecutionProviderApi::KernelDefBuilder_SetInputMemType, + &OrtExecutionProviderApi::KernelDefBuilder_SetOutputMemType, + &OrtExecutionProviderApi::KernelDefBuilder_AddTypeConstraint, + &OrtExecutionProviderApi::KernelDefBuilder_AddInputOutputAliases, + &OrtExecutionProviderApi::KernelDefBuilder_AddInputOutputMutableAliases, + &OrtExecutionProviderApi::KernelDefBuilder_Build, + &OrtExecutionProviderApi::ReleaseKernelDef, + &OrtExecutionProviderApi::KernelDef_GetOperatorType, + &OrtExecutionProviderApi::KernelDef_GetDomain, + &OrtExecutionProviderApi::KernelDef_GetSinceVersion, + &OrtExecutionProviderApi::KernelDef_GetExecutionProvider, + &OrtExecutionProviderApi::KernelDef_GetInputMemType, + &OrtExecutionProviderApi::KernelDef_GetOutputMemType, + &OrtExecutionProviderApi::GetTensorDataType, + &OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index 129230be4f618..b6a7262ec2008 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -47,4 +47,57 @@ ORT_API_STATUS_IMPL(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, _In_opt_ const OrtKeyValuePairs* metadata, _Out_ OrtHardwareDevice** hardware_device); ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device); + +// OrtKernelRegistry +ORT_API_STATUS_IMPL(CreateKernelRegistry, _Outptr_ OrtKernelRegistry** kernel_registry); +ORT_API(void, ReleaseKernelRegistry, _Frees_ptr_opt_ OrtKernelRegistry* kernel_registry); +ORT_API_STATUS_IMPL(KernelRegistry_AddKernel, _In_ OrtKernelRegistry* kernel_registry, + _In_ const OrtKernelDef* kernel_def, _In_ OrtKernelCreateFunc kernel_create_func, + _In_ void* kernel_create_func_state); + +// OrtKernelDefBuilder +ORT_API_STATUS_IMPL(CreateKernelDefBuilder, _Outptr_ OrtKernelDefBuilder** kernel_def_builder_out); +ORT_API(void, ReleaseKernelDefBuilder, _Frees_ptr_opt_ OrtKernelDefBuilder* kernel_def_builder); 
+ORT_API_STATUS_IMPL(KernelDefBuilder_SetOperatorType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* op_type); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetDomain, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* domain); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetSinceVersion, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ int since_version_start, _In_ int since_version_end); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetExecutionProvider, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* ep_name); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetInputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t input_index, _In_ OrtMemType mem_type); +ORT_API_STATUS_IMPL(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ size_t output_index, _In_ OrtMemType mem_type); +ORT_API_STATUS_IMPL(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_ const char* arg_name, _In_reads_(num_types) const OrtDataType* const* types, + _In_ size_t num_types); +ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputAliases, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_reads_(num_io_indices) int const* input_indices, + _In_reads_(num_io_indices) int const* output_indices, + _In_ size_t num_io_indices); +ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputMutableAliases, _In_ OrtKernelDefBuilder* kernel_def_builder, + _In_reads_(num_io_indices) int const* input_indices, + _In_reads_(num_io_indices) int const* output_indices, + _In_ size_t num_io_indices); +ORT_API_STATUS_IMPL(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder, + _Outptr_ OrtKernelDef** kernel_def_out); + +// OrtKernelDef +ORT_API(void, ReleaseKernelDef, _Frees_ptr_opt_ OrtKernelDef* kernel_def); +ORT_API(const char*, KernelDef_GetOperatorType, _In_ const OrtKernelDef* kernel_def); +ORT_API(const char*, KernelDef_GetDomain, _In_ const OrtKernelDef* kernel_def); +ORT_API_STATUS_IMPL(KernelDef_GetSinceVersion, _In_ const OrtKernelDef* kernel_def, + _Out_ int* start_version, _Out_ int* end_version); +ORT_API(const char*, KernelDef_GetExecutionProvider, _In_ const OrtKernelDef* kernel_def); +ORT_API_STATUS_IMPL(KernelDef_GetInputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t input_index, _Out_ OrtMemType* mem_type); +ORT_API_STATUS_IMPL(KernelDef_GetOutputMemType, _In_ const OrtKernelDef* kernel_def, + _In_ size_t output_index, _Out_ OrtMemType* mem_type); + +ORT_API_STATUS_IMPL(GetTensorDataType, _In_ ONNXTensorElementDataType elem_type, + _Outptr_ const OrtDataType** out); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc new file mode 100644 index 0000000000000..8dfb0a7ab06b4 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/plugin_ep/ep_kernel_registration.h" + +#include +#include + +#include "core/framework/error_code_helper.h" +#include "core/framework/kernel_registry.h" +#include "core/session/plugin_ep/ep_api.h" + +namespace onnxruntime { + +/// +/// OpKernel that wraps a OrtKernelImpl provided by a plugin EP. +/// +class PluginEpOpKernel final : public OpKernel { + private: + struct PrivateTag {}; + + public: + PluginEpOpKernel(const OpKernelInfo& info, PrivateTag) : OpKernel{info} {} // must use ::Create() + + static Status Create(FuncManager& fn_manager, const OpKernelInfo& info, + OrtKernelCreateFunc kernel_create_func, void* kernel_create_func_state, + /*out*/ std::unique_ptr& op_kernel); + + ~PluginEpOpKernel() { + if (kernel_impl_ != nullptr) { + kernel_impl_->Release(kernel_impl_); + } + } + + Status Compute(OpKernelContext* ctx) const override { + assert(kernel_impl_ != nullptr); // Should be ensured by PluginEpOpKernel::Create(). + return ToStatusAndRelease(kernel_impl_->Compute(kernel_impl_, reinterpret_cast(ctx))); + } + + private: + OrtKernelImpl* kernel_impl_ = nullptr; +}; + +/*static*/ +Status PluginEpOpKernel::Create(FuncManager& /*fn_manager*/, const OpKernelInfo& info, + OrtKernelCreateFunc kernel_create_func, void* kernel_create_func_state, + /*out*/ std::unique_ptr& op_kernel) { + // OpKernel's constructor *copies* the OpKernelInfo. + // Therefore, must create the OpKernel instance immediately so that we can pass the actual OpKernelInfo + // to the plugin EP's kernel creation function. + op_kernel = std::make_unique(info, PrivateTag{}); + const OrtKernelInfo* kernel_info = reinterpret_cast(&op_kernel->Info()); + + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + kernel_create_func(kernel_create_func_state, kernel_info, &op_kernel->kernel_impl_))); + ORT_RETURN_IF(op_kernel->kernel_impl_ == nullptr, "OrtKernelCreateFunc returned a NULL OrtKernelImpl"); + + return Status::OK(); +} + +/// +/// A functor that creates a PluginEpOpKernel instance using the creation function (+ state) provided by a plugin EP. +/// +class PluginEpKernelCreateFunctor { + public: + PluginEpKernelCreateFunctor(OrtKernelCreateFunc create_func, void* state) + : kernel_create_func_{create_func}, kernel_create_func_state_{state} {} + + Status operator()(FuncManager& fn_manager, const OpKernelInfo& info, std::unique_ptr& out) { + if (kernel_create_func_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "PluginEpKernelCreateFunctor does not wrap a valid OrtKernelCreateFunc"); + } + + std::unique_ptr plugin_ep_op_kernel; + ORT_RETURN_IF_ERROR(PluginEpOpKernel::Create(fn_manager, info, kernel_create_func_, kernel_create_func_state_, + plugin_ep_op_kernel)); + + out = std::move(plugin_ep_op_kernel); + return Status::OK(); + } + + private: + OrtKernelCreateFunc kernel_create_func_ = nullptr; + void* kernel_create_func_state_ = nullptr; +}; + +// Make a KernelCreateInfo for a plugin EP's kernel +KernelCreateInfo MakePluginEpKernelCreateInfo(const KernelDef* kernel_def, + OrtKernelCreateFunc kernel_create_func, + void* kernel_create_func_state) { + auto kernel_def_copy = std::make_unique(*kernel_def); + PluginEpKernelCreateFunctor kernel_create_functor(kernel_create_func, kernel_create_func_state); + return KernelCreateInfo(std::move(kernel_def_copy), kernel_create_functor); +} + +// Copies a const OrtKernelRegistry into a shared_ptr. 
+static Status CopyEpKernelRegistry(const OrtKernelRegistry* ep_registry, + /*out*/ std::shared_ptr& registry_copy) { + if (ep_registry == nullptr) { + registry_copy = nullptr; + return Status::OK(); + } + + const KernelRegistry* src_registry = reinterpret_cast(ep_registry); + auto dst_registry = std::make_shared(); + + for (const auto& [key, src_create_info] : src_registry->GetKernelCreateMap()) { + auto dst_kernel_def = std::make_unique(*src_create_info.kernel_def); + KernelCreateInfo dst_create_info(std::move(dst_kernel_def), src_create_info.kernel_create_func); + + ORT_RETURN_IF_ERROR(dst_registry->Register(std::move(dst_create_info))); + } + + registry_copy = std::move(dst_registry); + return Status::OK(); +} + +// Gets an OrtEp instance's kernel registry. +Status GetPluginEpKernelRegistry(OrtEp& ort_ep, /*out*/ std::shared_ptr& kernel_registry) { + kernel_registry = nullptr; + + if (ort_ep.ort_version_supported < 24) { + // OrtEp::GetKernelRegistry was added in ORT 1.24.0, but this OrtEp uses an older ORT version. + return Status::OK(); + } + + if (ort_ep.GetKernelRegistry != nullptr) { + const OrtKernelRegistry* ep_registry = nullptr; + + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ort_ep.GetKernelRegistry(&ort_ep, &ep_registry))); + + // ORT needs a shared_ptr due to the IExecutionProvider::GetKernelRegistry() interface. + // We copy the EP's OrtKernelRegistry into a new shared_ptr to ensure the EP fully owns + // the lifetime of the registry it created. + ORT_RETURN_IF_ERROR(CopyEpKernelRegistry(ep_registry, kernel_registry)); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h new file mode 100644 index 0000000000000..a7fd1759697df --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/session/onnxruntime_c_api.h" +#include "core/framework/data_types.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/kernel_def_builder.h" +#include "core/framework/kernel_registry.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +/// +/// Make a KernelCreateInfo for a plugin EP's kernel. A KernelCreateInfo contains the function and state +/// necessary to create a kernel. +/// +/// +/// +/// +/// +KernelCreateInfo MakePluginEpKernelCreateInfo(const KernelDef* kernel_def, + OrtKernelCreateFunc kernel_create_func, + void* kernel_create_func_state); + +/// +/// Gets the kernel registry for a plugin EP. +/// +/// The OrtEp instance. +/// Output parameter set to the EP's registry. 
+/// A status indicating success or an error +Status GetPluginEpKernelRegistry(OrtEp& ort_ep, /*out*/ std::shared_ptr& kernel_registry); + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 55245420db37a..ba6b3c7da9471 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -21,6 +21,7 @@ #include "core/session/abi_logger.h" #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" +#include "core/session/plugin_ep/ep_kernel_registration.h" #include "core/session/ort_apis.h" #include "core/providers/partitioning_utils.h" @@ -43,21 +44,51 @@ PluginExecutionProviderFactory::PluginExecutionProviderFactory(OrtEpFactory& ep_ } } +PluginExecutionProviderFactory::PluginExecutionProviderFactory(OrtEpFactory& ep_factory, + gsl::span ep_devices, + gsl::span hw_devices, + gsl::span ep_metadata) + : ep_factory_{ep_factory}, + devices_{ep_devices.begin(), ep_devices.end()}, + hardware_devices_{hw_devices.begin(), hw_devices.end()}, + ep_metadata_{ep_metadata.begin(), ep_metadata.end()} { +} + std::unique_ptr PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_options, const OrtLogger& session_logger) { - OrtEp* ort_ep = nullptr; - Status status = ToStatusAndRelease(ep_factory_.CreateEp(&ep_factory_, hardware_devices_.data(), ep_metadata_.data(), - hardware_devices_.size(), &session_options, &session_logger, - &ort_ep)); + std::unique_ptr plugin_ep; + Status status = CreatePluginExecutionProvider(session_options, session_logger, plugin_ep); if (!status.IsOK()) { - ORT_THROW("Error creating execution provider: ", status.ToString()); + LOGS(*session_logger.ToInternal(), ERROR) << "Error creating execution provider: " << status.ToString(); + return nullptr; } - return std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), - session_options, ep_factory_, devices_, - *session_logger.ToInternal()); + return plugin_ep; +} + +Status PluginExecutionProviderFactory::CreatePluginExecutionProvider( + const OrtSessionOptions& session_options, + const OrtLogger& logger, + /*out*/ std::unique_ptr& plugin_ep) { + plugin_ep = nullptr; + OrtEp* ort_ep = nullptr; + + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.CreateEp(&ep_factory_, hardware_devices_.data(), + ep_metadata_.data(), hardware_devices_.size(), + &session_options, &logger, &ort_ep))); + ORT_RETURN_IF(ort_ep == nullptr, "OrtEpFactory::CreateEp() for '", ep_factory_.GetName(&ep_factory_), + "' returned a NULL OrtEp instance"); + + std::shared_ptr kernel_registry; + ORT_RETURN_IF_ERROR(GetPluginEpKernelRegistry(*ort_ep, kernel_registry)); + + plugin_ep = std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), + session_options, ep_factory_, devices_, + kernel_registry, + *logger.ToInternal()); + return Status::OK(); } /// @@ -132,11 +163,13 @@ static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type, PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, gsl::span ep_devices, + std::shared_ptr kernel_registry, const logging::Logger& logger) : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), - ep_devices_(ep_devices.begin(), ep_devices.end()) { + 
ep_devices_(ep_devices.begin(), ep_devices.end()), + kernel_registry_(std::move(kernel_registry)) { generate_ep_ctx_model_ = session_options.value.GetEpContextGenerationOptions().enable; for (const auto* ep_device : ep_devices_) { @@ -155,12 +188,16 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio } PluginExecutionProvider::~PluginExecutionProvider() { - if (ort_ep_ && !api_node_compute_infos_.empty()) { + if (ort_ep_ && !api_node_compute_infos_.empty() && ort_ep_->ReleaseNodeComputeInfos != nullptr) { ort_ep_->ReleaseNodeComputeInfos(ort_ep_.get(), api_node_compute_infos_.data(), api_node_compute_infos_.size()); } } +std::shared_ptr PluginExecutionProvider::GetKernelRegistry() const { + return kernel_registry_; +} + std::vector> PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, @@ -168,7 +205,6 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie IResourceAccountant* resource_accountant) const { ORT_UNUSED_PARAMETER(graph_optimizer_registry); // TODO: Add support ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs - ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed? const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger(); @@ -178,7 +214,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } - OrtEpGraphSupportInfo api_graph_support_info(*ep_graph); + OrtEpGraphSupportInfo api_graph_support_info(*ep_graph, kernel_lookup); Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); if (!status.IsOK()) { @@ -377,7 +413,11 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto Status PluginExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_infos) { - const logging::Logger* logger = GetLogger(); + ORT_RETURN_IF(ort_ep_->Compile == nullptr, "OrtEp for ", Type(), " did not provide a valid Compile() function"); + ORT_RETURN_IF(ort_ep_->ReleaseNodeComputeInfos == nullptr, "OrtEp for ", Type(), + " did not provide a valid ReleaseNodeComputeInfos() function"); + + const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger(); const size_t num_graphs = fused_nodes_and_graphs.size(); std::vector> api_graphs_holder; std::vector api_graphs; @@ -443,16 +483,16 @@ Status PluginExecutionProvider::Compile(const std::vector& fu "instance for graph at index ", i); NodeComputeInfo compute_info; - compute_info.create_state_func = [api_node_compute_info, logger](ComputeContext* context, - FunctionState* compute_state) -> int { + compute_info.create_state_func = [api_node_compute_info, &logger](ComputeContext* context, + FunctionState* compute_state) -> int { Status status = ToStatusAndRelease( api_node_compute_info->CreateState(api_node_compute_info, reinterpret_cast(context), compute_state)); const bool success = status.IsOK(); if (!success) { - LOGS(*logger, ERROR) << "OrtNodeComputeInfo::CreateComputeState() failed with error: " - << status.ErrorMessage(); + LOGS(logger, ERROR) << "OrtNodeComputeInfo::CreateState() failed with error: " + << status.ErrorMessage(); } return success ? 
0 : 1; @@ -680,7 +720,7 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr ORT_ENFORCE(status == nullptr && stream != nullptr, "Error creating sync stream for device: ", ToStatusAndRelease(status).ToString()); - return std::make_unique(device, *stream, *GetLogger()); + return std::make_unique(device, *stream, *GetLogger()->ToExternal()); }); registry.RegisterWaitFn(device_type, device_type, plugin_ep::Notification::WaitNotificationOnDevice); diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 622bbb3f97b24..4fb42d8cf8484 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -10,6 +10,7 @@ #include #include "core/common/common.h" +#include "core/common/inlined_containers.h" #include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" @@ -18,6 +19,7 @@ namespace onnxruntime { struct EpNode; struct EpValueInfo; class NodeArg; +class PluginExecutionProvider; /// /// IExecutionProviderFactory that wraps a OrtEpFactory. Required for SessionOptionsAppendExecutionProvider_V2. @@ -26,6 +28,14 @@ struct PluginExecutionProviderFactory : public IExecutionProviderFactory { public: PluginExecutionProviderFactory(OrtEpFactory& ep_factory, gsl::span ep_devices); + // Constructor that accepts hw devices and ep metadata that have already been extracted from the given OrtEpDevice + // instances. It is an error to call this constructor with hw devices or ep metadata that do not correspond to the + // correct EP devices (e.g., hw_devices[i] and ep_metadata[i] should be extracted from ep_devices[i]). + PluginExecutionProviderFactory(OrtEpFactory& ep_factory, + gsl::span ep_devices, + gsl::span hw_devices, + gsl::span ep_metadata); + std::unique_ptr CreateProvider(const OrtSessionOptions& session_options, const OrtLogger& session_logger) override; @@ -33,11 +43,22 @@ struct PluginExecutionProviderFactory : public IExecutionProviderFactory { ORT_NOT_IMPLEMENTED("CreateProvider without parameters is not supported."); } + /// + /// Alternative version of CreateProvider that returns a Status. + /// + /// The session options to pass to the EP factory. + /// The session logger. Stored by the OrtEp. + /// Output parameter set to the newly created PluginExecutionProvider. + /// A status indicating success or an error. 
+ Status CreatePluginExecutionProvider(const OrtSessionOptions& session_options, + const OrtLogger& logger, + /*out*/ std::unique_ptr& plugin_ep); + private: OrtEpFactory& ep_factory_; - std::vector devices_; - std::vector hardware_devices_; - std::vector ep_metadata_; + InlinedVector devices_; + InlinedVector hardware_devices_; + InlinedVector ep_metadata_; }; /// @@ -65,9 +86,13 @@ class PluginExecutionProvider : public IExecutionProvider { public: explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, - gsl::span ep_devices, const logging::Logger& logger); + gsl::span ep_devices, + std::shared_ptr kernel_registry, + const logging::Logger& logger); ~PluginExecutionProvider(); + std::shared_ptr GetKernelRegistry() const override; + std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, @@ -136,5 +161,7 @@ class PluginExecutionProvider : public IExecutionProvider { // calls IExecutionProvider::GetEpContextNodes(). std::vector> ep_context_nodes_; std::vector> ep_context_node_args_; + + std::shared_ptr kernel_registry_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 7195bfbc77bab..0620d1991a258 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -168,6 +168,25 @@ struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Ite google::protobuf::internal::RepeatedPtrIterator v_; }; +struct TensorProto_ConstIterator_Impl : TensorProto_ConstIterator { + explicit TensorProto_ConstIterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} + + bool operator!=(const TensorProto_ConstIterator& p) const override { return v_ != static_cast(&p)->v_; } + + void operator++() override { v_.operator++(); } + const ONNX_NAMESPACE::TensorProto& operator*() const override { return *v_; } + + google::protobuf::internal::RepeatedPtrIterator v_; +}; + +struct TensorProto_Iterator_Impl : TensorProto_Iterator { + explicit TensorProto_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} + bool operator!=(const TensorProto_Iterator& p) const override { return v_ != reinterpret_cast(&p)->v_; } + void operator++() override { v_.operator++(); } + ONNX_NAMESPACE::TensorProto& operator*() const override { return *v_; } + google::protobuf::internal::RepeatedPtrIterator v_; +}; + struct NodeAttributes_Iterator_Impl : NodeAttributes_Iterator { NodeAttributes_Iterator_Impl(NodeAttributes::const_iterator&& v) : v_{std::move(v)} {} @@ -594,7 +613,14 @@ struct ProviderHostImpl : ProviderHost { std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_name(); } ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) override { return p->mutable_node(index); } - void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { *p = v; } + ONNX_NAMESPACE::GraphProto& GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { + *p = v; + return *p; + } + ONNX_NAMESPACE::GraphProto& GraphProto__operator_move_assign(ONNX_NAMESPACE::GraphProto* p, ONNX_NAMESPACE::GraphProto&& v) override { + *p = std::move(v); + return *p; + } void GraphProto__set_name(ONNX_NAMESPACE::GraphProto* p, const std::string& name) 
override { p->set_name(name); } void GraphProto__set_doc_string(ONNX_NAMESPACE::GraphProto* p, const std::string& doc_str) override { @@ -633,7 +659,14 @@ struct ProviderHostImpl : ProviderHost { // TensorProto (wrapped) std::unique_ptr TensorProto__construct() override { return std::make_unique(); } void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) override { delete p; } - void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override { *p = v; } + ONNX_NAMESPACE::TensorProto& TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override { + *p = v; + return *p; + } + ONNX_NAMESPACE::TensorProto& TensorProto__operator_move_assign(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto&& v) override { + *p = std::move(v); + return *p; + } bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_name(); } void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) override { p->set_name(name); } const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) override { return p->name(); } @@ -663,8 +696,20 @@ struct ProviderHostImpl : ProviderHost { // TensorProtos (wrapped) ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); } - int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); } + int TensorProtos__size(const ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); } ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) override { return p->at(index); }; + std::unique_ptr TensorProtos__begin(const ONNX_NAMESPACE::TensorProtos* p) override { + return std::make_unique(p->begin()); + } + std::unique_ptr TensorProtos__end(const ONNX_NAMESPACE::TensorProtos* p) override { + return std::make_unique(p->end()); + } + std::unique_ptr TensorProtos__begin(ONNX_NAMESPACE::TensorProtos* p) override { + return std::make_unique(p->begin()); + } + std::unique_ptr TensorProtos__end(ONNX_NAMESPACE::TensorProtos* p) override { + return std::make_unique(p->end()); + } // TensorShapeProto_Dimension (wrapped) int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->value_case(); } diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 6bcbda0f13b92..aa2859985a479 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -7,6 +7,7 @@ #include #include +#include #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" @@ -357,12 +358,12 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or info.hardware_devices.size(), &options, &logger, &ep))); } else { - OrtEp* api_ep = nullptr; - ORT_RETURN_IF_ERROR(ToStatusAndRelease( - info.ep_factory->CreateEp(info.ep_factory, info.hardware_devices.data(), info.ep_metadata.data(), - info.hardware_devices.size(), &options, &logger, &api_ep))); - ep = std::make_unique(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options, - *info.ep_factory, info.devices, *logger.ToInternal()); + PluginExecutionProviderFactory factory(*info.ep_factory, info.devices, info.hardware_devices, info.ep_metadata); + std::unique_ptr plugin_ep; + + ORT_RETURN_IF_ERROR(factory.CreatePluginExecutionProvider(options, logger, 
plugin_ep)); + + ep = std::move(plugin_ep); } return Status::OK(); diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 48d52ae3cf428..e2ab0036c238f 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -628,6 +628,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, return CreateNotEnabledStatus("VitisAI"); } #endif + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* provider_options) { ORT_UNUSED_PARAMETER(options); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 6189e6ca7f012..4cb21b80109c8 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -404,6 +404,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model session))); } + Env::Default().GetTelemetryProvider().LogCompileModel(session->GetCurrentSessionId()); ORT_RETURN_IF_ERROR(ToStatusAndRelease(InitializeSession(session_options, *session))); return Status::OK(); } diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 91216473bcad2..7b4f130cc2b93 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -547,16 +547,6 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] else: self._fallback_providers = ["CPUExecutionProvider"] - # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU. 
- elif "MIGraphXExecutionProvider" in available_providers: - if providers and any( - provider == "ROCMExecutionProvider" - or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider") - for provider in providers - ): - self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] - else: - self._fallback_providers = ["CPUExecutionProvider"] else: self._fallback_providers = ["CPUExecutionProvider"] diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 1934e0eda7956..14330655e1ecc 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -154,14 +154,6 @@ OrtMemoryInfo GetMemoryInfoPerDeviceType(const OrtDevice& ort_device) { mem_info = GetCudaAllocator(ort_device.Id())->Info(); } #endif -#if USE_ROCM - else if (ort_device.Type() == OrtDevice::GPU) { - if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), ort_device.Id())) { - ORT_THROW("The provided device id doesn't match any available GPUs on the machine: ", ort_device.Id()); - } - mem_info = GetRocmAllocator(ort_device.Id())->Info(); - } -#endif #if USE_MIGRAPHX else if (ort_device.Type() == OrtDevice::GPU) { mem_info = GetMIGraphXAllocator(ort_device.Id())->Info(); @@ -440,55 +432,6 @@ AllocatorPtr GetCannAllocator(OrtDevice::DeviceId id) { #endif -#ifdef USE_ROCM -void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes) { - GetProviderInfo_ROCM().rocmMemcpy_HostToDevice(dst, src, num_bytes); -} - -void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { - GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes); -} - -const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice& device) { - static std::unordered_map map{ - {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, RocmToCpuMemCpy}, - }; - - return ↦ -} - -bool IsRocmDeviceIdValid(const onnxruntime::logging::Logger& logger, int id) { - int num_devices = GetProviderInfo_ROCM().hipGetDeviceCount(); - - if (0 == num_devices) { - LOGS(logger, WARNING) << "your system does not have a ROCM capable device."; - return false; - } - - if (id < 0 || id >= num_devices) { - LOGS(logger, WARNING) << "rocm_device=" << id << " is invalid, must choose device ID between 0 and " << num_devices - 1; - return false; - } - - return true; -} - -AllocatorPtr GetRocmAllocator(OrtDevice::DeviceId id) { - // Current approach is not thread-safe, but there are some bigger infra pieces to put together in order to make - // multi-threaded ROCM allocation work we need to maintain a per-thread ROCM allocator - - static auto* id_to_allocator_map = new std::unordered_map(); - - if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) { - // TODO: Expose knobs so that users can set fields associated with OrtArenaCfg so that we can pass it to the following method - id_to_allocator_map->insert({id, GetProviderInfo_ROCM().CreateRocmAllocator(id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, nullptr)}); - } - - return (*id_to_allocator_map)[id]; -} - -#endif - int OnnxRuntimeTensorToNumpyType(const DataTypeImpl* tensor_type) { static std::map type_map{ {DataTypeImpl::GetType(), NPY_BOOL}, diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h index eba783d826212..377122a8bf73e 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h +++ 
b/onnxruntime/python/onnxruntime_pybind_mlvalue.h @@ -122,20 +122,6 @@ AllocatorPtr GetCannAllocator(OrtDevice::DeviceId id); #endif -#ifdef USE_ROCM - -bool IsRocmDeviceIdValid(const onnxruntime::logging::Logger& logger, int id); - -AllocatorPtr GetRocmAllocator(OrtDevice::DeviceId id); - -void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes); - -void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes); - -const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice&); - -#endif - void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const AllocatorPtr& alloc, const std::string& name_input, const pybind11::object& value, OrtValue* p_mlvalue, bool accept_only_numpy_array = false, bool use_numpy_data_memory = true, diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index d74663ddb63d7..f996bf213b4a0 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -47,13 +47,7 @@ std::unique_ptr OrtValueFromShapeAndType(const std::vector& s "Please use the CUDA package of OnnxRuntime to use this feature."); #endif } else if (strcmp(GetDeviceName(device), HIP) == 0) { -#if USE_ROCM - if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } - - allocator = GetRocmAllocator(device.Id()); -#elif USE_MIGRAPHX +#if USE_MIGRAPHX allocator = GetMIGraphXAllocator(device.Id()); #else throw std::runtime_error( @@ -125,20 +119,6 @@ void addOrtValueMethods(pybind11::module& m) { true, false, CpuToCudaMemCpy); } else #endif -#ifdef USE_ROCM - if (device.Vendor() == OrtDevice::VendorIds::AMD) { - if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } - - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors - // in ROCM - CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), - true, false, CpuToRocmMemCpy); - } else -#endif #if USE_MIGRAPHX if (device.Vendor() == OrtDevice::VendorIds::AMD) { // InputDeflist is null because OrtValue creation is not tied to a specific model @@ -212,19 +192,6 @@ void addOrtValueMethods(pybind11::module& m) { CpuToCudaMemCpy); } else #endif -#if USE_ROCM - if (device.Vendor() == OrtDevice::VendorIds::AMD) { - if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } - - onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToRocmMemCpy); - } else -#endif #if USE_MIGRAPHX if (device.Vendor() == OrtDevice::VendorIds::AMD) { onnxruntime::python::CopyDataToTensor( diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index cd1d2a8da10aa..8cb617fe5226c 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -29,12 +29,6 @@ void addGlobalSchemaFunctions(pybind11::module& m) { 
return CudaProviderFactoryCreator::Create(&provider_options); }(), #endif -#ifdef USE_ROCM - []() { - OrtROCMProviderOptions provider_options; - return onnxruntime::RocmProviderFactoryCreator::Create(&provider_options); - }(), -#endif #ifdef USE_DNNL onnxruntime::DnnlProviderFactoryCreator::Create(1), #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 92cf6b085c01e..24487c4a7b844 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -509,27 +509,6 @@ const CANNExecutionProviderInfo GetCannExecutionProviderInfo(ProviderInfo_CANN* } #endif -#ifdef USE_ROCM -const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* rocm_provider_info, - const ProviderOptionsMap& provider_options_map) { - ORT_ENFORCE(rocm_provider_info); - const auto it = provider_options_map.find(kRocmExecutionProvider); - ROCMExecutionProviderInfo info; - if (it != provider_options_map.end()) - rocm_provider_info->ROCMExecutionProviderInfo__FromProviderOptions(it->second, info); - else { - info.device_id = cuda_device_id; - info.gpu_mem_limit = gpu_mem_limit; - info.arena_extend_strategy = arena_extend_strategy; - info.miopen_conv_exhaustive_search = miopen_conv_exhaustive_search; - info.do_copy_in_default_stream = do_copy_in_default_stream; - info.external_allocator_info = external_allocator_info; - info.tunable_op = tunable_op; - } - return info; -} -#endif - #if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) { if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) { @@ -1029,26 +1008,6 @@ static std::shared_ptr CreateExecutionProviderFactory "make sure they're in the PATH, and that your GPU is supported."; #endif // defined(USE_CUDA) #endif // defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) - } else if (type == kRocmExecutionProvider) { -#ifdef USE_ROCM - if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) { - const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info, - provider_options_map); - - // This variable is never initialized because the APIs by which is it should be initialized are deprecated, - // however they still exist and are in-use. Nevertheless, it is used to return ROCMAllocator, hence we must - // try to initialize it here if we can since FromProviderOptions might contain external ROCM allocator. - external_allocator_info = info.external_allocator_info; - return rocm_provider_info->CreateExecutionProviderFactory(info); - } else { - if (!Env::Default().GetEnvironmentVar("ROCM_PATH").empty()) { - ORT_THROW( - "ROCM_PATH is set but ROCM wasn't able to be loaded. 
Please install the correct version " - "of ROCM and MIOpen as mentioned in the GPU requirements page, make sure they're in the PATH, " - "and that your GPU is supported."); - } - } -#endif } else if (type == kDnnlExecutionProvider) { #ifdef USE_DNNL // Generate dnnl_options @@ -1475,7 +1434,7 @@ bool CheckIfTensor(const std::vector& def_list, } #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) || \ - defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) + defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) static void LogDeprecationWarning( const std::string& deprecated, const optional& alternative = nullopt) { LOGS_DEFAULT(WARNING) << "This is DEPRECATED and will be removed in the future: " << deprecated; @@ -1629,7 +1588,7 @@ void addGlobalMethods(py::module& m) { "Gets the dynamically selected OpenVINO device type for inference."); #endif -#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) /* * The following set_* methods are deprecated. * @@ -1639,40 +1598,30 @@ void addGlobalMethods(py::module& m) { */ // TODO remove deprecated global config m.def("set_cuda_device_id", [](const int id) { - LogDeprecationWarning("set_cuda_device_id", "CUDA/ROCM execution provider option \"device_id\""); + LogDeprecationWarning("set_cuda_device_id", "CUDA execution provider option \"device_id\""); cuda_device_id = static_cast(id); }); // TODO remove deprecated global config m.def("set_cudnn_conv_algo_search", [](const OrtCudnnConvAlgoSearch algo) { LogDeprecationWarning("set_cudnn_conv_algo_search", "CUDA execution provider option \"cudnn_conv_algo_search\""); -#ifdef USE_ROCM - ORT_UNUSED_PARAMETER(algo); - ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM"); -#else cudnn_conv_algo_search = algo; -#endif }); // TODO remove deprecated global config m.def("set_do_copy_in_default_stream", [](const bool use_single_stream) { LogDeprecationWarning( "set_do_copy_in_default_stream", "CUDA execution provider option \"do_copy_in_default_stream\""); -#ifdef USE_ROCM - ORT_UNUSED_PARAMETER(use_single_stream); - ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM"); -#else do_copy_in_default_stream = use_single_stream; -#endif }); // TODO remove deprecated global config m.def("set_gpu_mem_limit", [](const int64_t limit) { LogDeprecationWarning( "set_gpu_mem_limit", - "CUDA execution provider option \"gpu_mem_limit\", ROCM execution provider option \"gpu_mem_limit\""); + "CUDA execution provider option \"gpu_mem_limit\""); gpu_mem_limit = gsl::narrow(limit); }); // TODO remove deprecated global config m.def("set_arena_extend_strategy", [](const onnxruntime::ArenaExtendStrategy strategy) { - LogDeprecationWarning("set_arena_extend_strategy", "CUDA/ROCM execution provider option \"arena_extend_strategy\""); + LogDeprecationWarning("set_arena_extend_strategy", "CUDA execution provider option \"arena_extend_strategy\""); arena_extend_strategy = strategy; }); #endif @@ -1825,7 +1774,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra } else if (type == OrtDevice::GPU) { #if USE_CUDA || USE_NV || USE_NV_PROVIDER_INTERFACE || USE_CUDA_PROVIDER_INTERFACE vendor = OrtDevice::VendorIds::NVIDIA; -#elif USE_ROCM || USE_MIGRAPHX +#elif USE_MIGRAPHX vendor = OrtDevice::VendorIds::AMD; #endif } else if (type == OrtDevice::NPU) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc 
b/onnxruntime/python/onnxruntime_pybind_state_common.cc index cccdb9d23900a..0d00f3c4e6eb0 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -31,17 +31,7 @@ onnxruntime::CUDAExecutionProviderExternalAllocatorInfo external_allocator_info{ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; #endif -#ifdef USE_ROCM -// TODO remove deprecated global config -bool miopen_conv_exhaustive_search = false; -// TODO remove deprecated global config -bool do_copy_in_default_stream = true; -// TODO remove deprecated global config -onnxruntime::rocm::TunableOpInfo tunable_op{}; -onnxruntime::ROCMExecutionProviderExternalAllocatorInfo external_allocator_info{}; -#endif - -#if defined(USE_ROCM) || defined(USE_MIGRAPHX) +#if defined(USE_MIGRAPHX) // TODO remove deprecated global config onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index b4a33e798f942..30ca76877dd0d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -34,7 +34,7 @@ struct OrtStatus { #include "core/providers/tensorrt/tensorrt_provider_options.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) #define BACKEND_PROC "GPU" #else #define BACKEND_PROC "CPU" @@ -122,10 +122,6 @@ struct OrtStatus { #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/cuda_execution_provider_info.h" #endif -#ifdef USE_ROCM -#include "core/providers/rocm/rocm_provider_factory.h" -#include "core/providers/rocm/rocm_execution_provider_info.h" -#endif #if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) #include "core/providers/tensorrt/tensorrt_provider_factory.h" #endif @@ -198,23 +194,7 @@ ProviderInfo_CANN& GetProviderInfo_CANN(); } // namespace onnxruntime #endif -#ifdef USE_ROCM -namespace onnxruntime { -ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); -ProviderInfo_ROCM& GetProviderInfo_ROCM(); -namespace python { -// TODO remove deprecated global config -extern bool miopen_conv_exhaustive_search; -// TODO remove deprecated global config -extern bool do_copy_in_default_stream; -// TODO remove deprecated global config -extern onnxruntime::rocm::TunableOpInfo tunable_op; -extern onnxruntime::ROCMExecutionProviderExternalAllocatorInfo external_allocator_info; -} // namespace python -} // namespace onnxruntime -#endif - -#if defined(USE_ROCM) || defined(USE_MIGRAPHX) +#if defined(USE_MIGRAPHX) namespace onnxruntime { namespace python { extern onnxruntime::ArenaExtendStrategy arena_extend_strategy; diff --git a/onnxruntime/python/tools/microbench/benchmark.py b/onnxruntime/python/tools/microbench/benchmark.py index a5936afcfe13e..257548b612c73 100644 --- a/onnxruntime/python/tools/microbench/benchmark.py +++ b/onnxruntime/python/tools/microbench/benchmark.py @@ -32,12 +32,12 @@ def add_arguments(parser: ArgumentParser): "--provider", required=False, type=str, - choices=["cuda", "rocm", "cpu", None], + choices=["cuda", "cpu", None], default=None, help=( "Execution provider to use. By default, a " "provider is selected in the priority order " - "(cuda|rocm, cpu) depending on availability." + "(cuda, cpu) depending on availability." 
), ) parser.add_argument( @@ -60,7 +60,6 @@ def add_arguments(parser: ArgumentParser): def provider_name(name): provider_map = { "cuda": "CUDAExecutionProvider", - "rocm": "ROCMExecutionProvider", "cpu": "CPUExecutionProvider", } return provider_map[name] @@ -69,8 +68,6 @@ def provider_name(name): def get_default_provider(): if "CUDAExecutionProvider" in ort.get_available_providers(): return "CUDAExecutionProvider" - if "ROCMExecutionProvider" in ort.get_available_providers(): - return "ROCMExecutionProvider" return "CPUExecutionProvider" @@ -85,7 +82,7 @@ def __init__(self, model, inputs, outputs, args): self.outputs = outputs def create_input_output_tensors(self): - on_gpu = self.provider == "CUDAExecutionProvider" or self.provider == "ROCMExecutionProvider" + on_gpu = self.provider == "CUDAExecutionProvider" device = "cuda" if on_gpu else "cpu" input_tensors = {name: torch.from_numpy(array).to(device) for name, array in self.inputs.items()} output_tensors = {name: torch.from_numpy(array).to(device) for name, array in self.outputs.items()} diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index c588689187383..cc3bc2ef28c4f 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -74,34 +74,29 @@ def quant_pre_process( with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir: temp_path = Path(quant_tmp_dir) - model = None + model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model) + + # Since Upsample is deprecated after opset v10, and the model's opset will + # be upgraded to at least v11 during quantization, we need to replace Upsample + # with Resize first to avoid generating an invalid model. + ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"] + if len(ai_onnx_domain) == 1: + opset_version = ai_onnx_domain[0].version + if opset_version <= 10: + ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply() + model = onnx.version_converter.convert_version(model, 11) + model = save_and_reload_model_with_shape_infer(model) if not skip_symbolic_shape: logger.info("Performing symbolic shape inference...") - loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model) model = SymbolicShapeInference.infer_shapes( - loaded_model, + model, int_max, auto_merge, guess_output_rank, verbose, ) - # Since Upsample is deprecated after opset v10, and the model's opset will - # be upgraded to at least v11 during quantization, we need to replace Upsample - # with Resize first to avoid generating an invalid model. 
- if model: - ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"] - if len(ai_onnx_domain) == 1: - opset_version = ai_onnx_domain[0].version - if opset_version < 10: - ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply() - model.opset_import.remove(ai_onnx_domain[0]) - opset_version = 11 - model.opset_import.extend([onnx.helper.make_opsetid("", opset_version)]) - model = onnx.version_converter.convert_version(model, opset_version) - model = save_and_reload_model_with_shape_infer(model) - if not skip_optimization: # Use ORT optimizers (native code) to optimize model if not skip_symbolic_shape: diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 77a9e31b6208f..eb29080734b40 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -34,8 +34,6 @@ python benchmark.py -e torchscript -g -p "fp16" Run ONNXRuntime and TorchScript on CPU for all models with quantization: python benchmark.py -e torchscript onnxruntime -p "int8" -o - Run OnnxRuntime with the ROCM provider and graph optimization script: - python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support: python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm @@ -118,7 +116,6 @@ def run_onnxruntime( use_gpu and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers()) and ("MIGraphXExecutionProvider" not in onnxruntime.get_available_providers()) - and ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()) and ("DmlExecutionProvider" not in onnxruntime.get_available_providers()) ): logger.error( @@ -788,7 +785,7 @@ def main(): logger.error("fp16 is for GPU only") return - if args.precision == Precision.INT8 and args.use_gpu and args.provider not in ["migraphx", "rocm"]: + if args.precision == Precision.INT8 and args.use_gpu and args.provider not in ["migraphx"]: logger.error("int8 is for CPU only") return diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index a6716c8df3bc2..8055e5e4ae876 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -112,12 +112,9 @@ def create_onnxruntime_session( elif use_gpu: if provider == "dml": providers = ["DmlExecutionProvider", "CPUExecutionProvider"] - elif provider == "rocm": - providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] elif provider == "migraphx": providers = [ "MIGraphXExecutionProvider", - "ROCMExecutionProvider", "CPUExecutionProvider", ] elif provider == "cuda" or provider is None: @@ -174,8 +171,8 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None): else: assert not set(onnxruntime.get_available_providers()).isdisjoint( - ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] - ), "Please install onnxruntime-gpu package, or install ROCm support, to test GPU inference." + ["CUDAExecutionProvider", "MIGraphXExecutionProvider"] + ), "Please install onnxruntime-gpu package, or install migraphx, to test GPU inference." 
logger.info(f"PyTorch Version:{torch.__version__}") logger.info(f"Transformers Version:{transformers.__version__}") diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py index ebf44e49c89bb..9920a413b699e 100644 --- a/onnxruntime/python/tools/transformers/bert_perf_test.py +++ b/onnxruntime/python/tools/transformers/bert_perf_test.py @@ -80,12 +80,9 @@ def create_session( if use_gpu: if provider == "dml": execution_providers = ["DmlExecutionProvider", "CPUExecutionProvider"] - elif provider == "rocm": - execution_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] elif provider == "migraphx": execution_providers = [ "MIGraphXExecutionProvider", - "ROCMExecutionProvider", "CPUExecutionProvider", ] elif provider == "cuda": @@ -128,11 +125,8 @@ def create_session( if use_gpu: if provider == "dml": assert "DmlExecutionProvider" in session.get_providers() - elif provider == "rocm": - assert "ROCMExecutionProvider" in session.get_providers() elif provider == "migraphx": assert "MIGraphXExecutionProvider" in session.get_providers() - assert "ROCMExecutionProvider" in session.get_providers() elif provider == "cuda": assert "CUDAExecutionProvider" in session.get_providers() elif provider == "tensorrt": diff --git a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py index f8b7dd80710ae..a4015f50fdc13 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py @@ -96,8 +96,8 @@ def parse_arguments(argv=None): "--provider", required=False, default=None, - choices=["dml", "rocm", "migraphx", "cuda", "tensorrt"], - help="use dml, rocm, cuda, tensorrt or migraphx for respective backend", + choices=["dml", "migraphx", "cuda", "tensorrt"], + help="use dml, cuda, tensorrt or migraphx for respective backend", ) parser.add_argument( diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 61bfc950735af..dbe4799f20b9c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -584,7 +584,7 @@ def get_args(rank=0): "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", - choices=["cpu", "cuda", "rocm"], + choices=["cpu", "cuda"], ) parser.add_argument("-id", "--device-id", type=int, default=0) parser.add_argument("-w", "--warmup-runs", type=int, default=5) @@ -622,9 +622,6 @@ def get_args(rank=0): setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010 if args.execution_provider == "CUDAExecutionProvider": args.execution_provider = (args.execution_provider, {"device_id": rank}) - elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": rank}) - args.device = "cuda" # Check that paths have been specified for any benchmarking with ORT if args.benchmark_type == "hf-ort": diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 6447a7322b6ed..059a69e492554 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -109,7 +109,7 @@ def get_args(): 
"--device", type=str, required=True, - choices=["cpu", "cuda", "rocm"], + choices=["cpu", "cuda"], help="Device to benchmark models", ) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index aa118da71525a..6411dca00b5de 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -631,7 +631,7 @@ def get_args(): "--execution_provider", required=False, default="cpu", - choices=["cpu", "cuda", "rocm"], + choices=["cpu", "cuda"], help="Execution provider to verify parity with", ) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 383101c8a3b72..f0aa07d3768b6 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -228,7 +228,7 @@ def get_args(argv: list[str]): "--execution_provider", required=False, default="cpu", - choices=["cpu", "cuda", "rocm"], + choices=["cpu", "cuda"], help="Execution provider to verify parity with", ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 2506ffe8a3f50..12e6df53de577 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -158,28 +158,6 @@ pip install -r requirements/cuda12/requirements.txt ``` Finally, `pip install tensorrt` for Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. -### Setup Environment (ROCm) - -It is recommended that the users run the model with ROCm 6.2 or newer and Python 3.10. You can follow the following to install ROCm 6.x: https://rocmdocs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html -Note that Windows is not supported for ROCm at the moment. - -``` -pip install -r requirements/rocm/requirements.txt -``` - -AMD GPU version of PyTorch can be installed from [pytorch.org](https://pytorch.org/get-started/locally/) or [AMD Radeon repo](https://repo.radeon.com/rocm/manylinux/rocm-rel-6.2.3/). - -#### Install onnxruntime-rocm - -One option is to install prebuilt wheel from https://repo.radeon.com/rocm/manylinux like: -``` -wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.2.3/onnxruntime_rocm-1.18.0-cp310-cp310-linux_x86_64.whl -pip install onnxruntime_rocm-1.18.0-cp310-cp310-linux_x86_64.whl -``` - -If you want to use latest version of onnxruntime, you can build from source with Rocm 6.x following https://onnxruntime.ai/docs/build/eps.html#amd-rocm. -When the build is finished, you can install the wheel:`pip install build/Linux/Release/dist/*.whl`. - ### Export ONNX pipeline This step will export stable diffusion 1.5 to ONNX model in float32 using script from diffusers. @@ -258,16 +236,6 @@ python benchmark.py -b 1 -v 1.5 For the first command, '-p' specifies a directory of optimized ONNX pipeline as generated by optimize_pipeline.py. For the second command without '-p', we will use ORTPipelineForText2Image to export and optimize ONNX models for clip, unet and vae decoder. 
-On ROCm EP, use the following command instead: -``` -python benchmark.py -p ./sd1.5_onnx/fp16 -b 1 --tuning --provider rocm -v 1.5 -``` - -For ROCm EP, you can substitute `python benchmark.py` with `python -m onnxruntime.transformers.models.stable_diffusion.benchmark` since -the installed package is built from source. For CUDA, it is recommended to run `python benchmark.py` with the latest benchmark script. - -For ROCm EP, the `--tuning` is mandatory because we heavily rely on tuning to find the runable kernels for ORT `OpKernel`s. - The default parameters are stable diffusion version=1.5, height=512, width=512, steps=50, batch_count=5. Run `python benchmark.py --help` for more information. #### Stable Diffusion 3.x and Flux 1.0 @@ -303,12 +271,6 @@ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu117 python benchmark.py -e torch -b 1 --enable_torch_compile -v 1.5 ``` -For ROCm: -``` -pip install torch --upgrade --index-url https://download.pytorch.org/whl/rocm5.4.2 -python benchmark.py -e torch -b 1 --enable_torch_compile --provider rocm -v 1.5 -``` - Sometime, it complains ptxas not found when there are multiple CUDA versions installed. It can be fixed like `export TRITON_PTXAS_PATH=/usr/local/cuda-11.7/bin/ptxas` before running benchmark. Note that torch.compile is not supported in Windows: we encountered error `Windows not yet supported for torch.compile`. So it is excluded from RTX 3060 results of Windows. @@ -352,65 +314,6 @@ Here FMHA means Attention and MultiHeadAttention operators with Flash Attention The last two optimizations (Packed QKV and BiasAdd) are only available in nightly package. Compared to 1.14.1, nightly package has slight improvement in performance. -### Results on MI250X with 1 GCD - -With runtime tuning enabled, we get following performance number on one GCD of a MI250X GPU: - -| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | -| --------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | -| Raw FP32 models | 6.7 | 17,319 | 36.4 * | 33,787 | -| FP16 baseline | 4.1 | 8,945 | 24.0 * | 34,493 | -| FP16 baseline + FMHA | 2.6 | 4,886 | 15.0 | 10,146 | -| FP16 baseline + FMHA + NhwcConv | 2.4 | 4,952 | 14.8 | 9,632 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm | 2.3 | 4,906 | 13.6 | 9,774 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 2.2 | 4,910 | 12.5 | 9,646 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + BiasAdd | 2.2 | 4,910 | 12.5 | 9,778 | - -The entries marked with `*` produce suspicious output images. The might be numerical stability or correctness issue for the pipeline. The performance number is for reference only. 
- - -### Example Benchmark output - -Common settings for below test results: - -| model_name | disable_safety_checker | height | width | steps | batch_count | num_prompts | -| ------------------------------ | ---------------------- | ------ | ----- | ----- | ----------- | ----------- | -| runwayml/stable-diffusion-v1-5 | TRUE | 512 | 512 | 50 | 5 | 1 | - -#### Results of MI250X, 1 GCD (Ubuntu 20.04) - -| engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | -| ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 1 | 2.2 | 5,548 | 4,908 | -| torch | 1.12.1+rocm5.4 | - | 1 | 3.4 | 6,653 | 4,613 | -| torch | 2.0.0+rocm5.4.2 | default | 1 | 3.2 | 5,977 | 4,368 | -| torch | 2.0.0+rocm5.4.2 | compile | 1 | 3.0 | 5,869 | 4,266 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 4 | 6.6 | 5,546 | 4,906 | -| torch | 1.12.1+rocm5.4 | - | 4 | 10.1 | 19,477 | 11,325 | -| torch | 2.0.0+rocm5.4.2 | default | 4 | 10.5 | 13,051 | 7,300 | -| torch | 2.0.0+rocm5.4.2 | compile | 4 | 9.2 | 12,879 | 7,190 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 8 | 12.5 | 9,778 | 9,006 | -| torch | 1.12.1+rocm5.4 | - | 8 | 19.3 | 55,851 | 20,014 | -| torch | 2.0.0+rocm5.4.2 | default | 8 | 20.3 | 23,551 | 11,930 | -| torch | 2.0.0+rocm5.4.2 | compile | 8 | 17.8 | 23,303 | 11,800 | - -#### Results of MI100 (Ubuntu 20.04) - -| engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | -| ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 1 | 2.4 | 5,254 | 4,614 | -| torch | 1.12.1+rocm5.4 | - | 1 | 3.5 | 5,771 | 4,672 | -| torch | 2.0.0+rocm5.4.2 | default | 1 | 3.5 | 5,811 | 4,206 | -| torch | 2.0.0+rocm5.4.2 | compile | 1 | 3.1 | 5,774 | 4,168 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 4 | 7.5 | 7,290 | 6,646 | -| torch | 1.12.1+rocm5.4 | - | 4 | 10.7 | 19,334 | 11,181 | -| torch | 2.0.0+rocm5.4.2 | default | 4 | 11.5 | 12,881 | 7,151 | -| torch | 2.0.0+rocm5.4.2 | compile | 4 | 10.0 | 12,740 | 7,073 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 8 | 14.4 | 7,320 | 6,676 | -| torch | 1.12.1+rocm5.4 | - | 8 | 20.2 | 31,820 | 19,908 | -| torch | 2.0.0+rocm5.4.2 | default | 8 | 22.2 | 23,415 | 11,815 | -| torch | 2.0.0+rocm5.4.2 | compile | 8 | 19.3 | 23,154 | 11,667 | - ### Credits Some CUDA kernels (TensorRT Fused Attention, GroupNorm, SplitGelu and BiasAdd etc.) and demo diffusion were originally implemented in [TensorRT](https://github.com/nviDIA/TensorRT) by Nvidia. @@ -418,9 +321,6 @@ We use [Flash Attention v2](https://github.com/Dao-AILab/flash-attention) in Lin We use Memory efficient attention from [CUTLASS](https://github.com/NVIDIA/cutlass). The kernels were developed by Meta xFormers. The ONNX export script and pipeline for stable diffusion was developed by Huggingface [diffusers](https://github.com/huggingface/diffusers) library. -Most ROCm kernel optimizations are from [composable kernel](https://github.com/ROCmSoftwarePlatform/composable_kernel). -Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' collaboration. - ### Future Works * Update demo to support inpainting. * Support flash attention in Windows. 
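The benchmark's `--tuning` switch, which is kept for the CUDA execution provider in the script changes below, enables ONNX Runtime's TunableOp. A short illustrative sketch of how the benchmark's `(provider, options)` tuple reaches a session; `unet.onnx` is a placeholder path.

```python
# Sketch: enable TunableOp on the CUDA execution provider the same way the
# stable diffusion benchmark does, by passing provider options as a (name, options) tuple.
import onnxruntime

cuda_with_tuning = (
    "CUDAExecutionProvider",
    {"tunable_op_enable": 1, "tunable_op_tuning_enable": 1},
)

session = onnxruntime.InferenceSession(
    "unet.onnx",
    providers=[cuda_with_tuning, "CPUExecutionProvider"],
)
```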
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index b4d2977050bd4..ed2e346972a6c 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -31,7 +31,6 @@ PROVIDERS = { "cuda": "CUDAExecutionProvider", - "rocm": "ROCMExecutionProvider", "migraphx": "MIGraphXExecutionProvider", "tensorrt": "TensorrtExecutionProvider", } @@ -328,7 +327,7 @@ def run_ort( skip_warmup: bool = False, ): provider_and_options = provider - if tuning and provider in ["CUDAExecutionProvider", "ROCMExecutionProvider"]: + if tuning and provider in ["CUDAExecutionProvider"]: provider_and_options = (provider, {"tunable_op_enable": 1, "tunable_op_tuning_enable": 1}) load_start = time.time() @@ -1150,8 +1149,7 @@ def parse_arguments(): "-t", "--tuning", action="store_true", - help="Enable TunableOp and tuning. " - "This will incur longer warmup latency, and is mandatory for some operators of ROCm EP.", + help="Enable TunableOp and tuning. This will incur longer warmup latency.", ) parser.add_argument( @@ -1336,7 +1334,7 @@ def main(): coloredlogs.install(fmt="%(funcName)20s: %(message)s") - memory_monitor_type = "rocm" if args.provider == "rocm" else "cuda" + memory_monitor_type = "cuda" start_memory = measure_gpu_memory(memory_monitor_type, None) print("GPU memory used before loading models:", start_memory) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/rocm/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/rocm/requirements.txt deleted file mode 100644 index 21b100fb61f17..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/rocm/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ --r ../requirements.txt -# Install onnxruntime-rocm that is built from source (https://onnxruntime.ai/docs/build/eps.html#amd-rocm) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 88fdad01baf92..04b62f4b2da99 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -130,9 +130,6 @@ def get_model(args: argparse.Namespace): if args.verbose: sess_options.log_verbosity_level = 1 sess_options.log_severity_level = 1 - if args.tune: - ort.set_default_logger_severity(0) - ort.set_default_logger_verbosity(0) else: raise Exception(f"Cannot recognize {args.benchmark_type}") @@ -338,9 +335,6 @@ def prepare_ort_inputs(inputs, warmup=False): logger.error(f"The following model inputs are missing: {missing_inputs}") raise Exception("There are missing inputs to the model. 
Please add them and try again.") - if warmup and args.tune: - inputs["min_length"] = inputs["max_length"] - # Remove unnecessary inputs from model inputs unnecessary_inputs = user_inputs - model_inputs if len(unnecessary_inputs): @@ -392,9 +386,6 @@ def handle_output(output): # ORT evaluation logger.info("\nEvaluating ONNX Runtime...") ort_evaluate_inputs = ort_inputs - if args.tune: - ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True) - ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs) time_fn(args, generate_fn, ort_evaluate_inputs) ort_outputs = generate_fn(ort_inputs) @@ -479,7 +470,7 @@ def parse_args(): "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", - choices=["cpu", "cuda", "rocm"], + choices=["cpu", "cuda"], ) parser.add_argument("-id", "--device-id", type=int, default=0) parser.add_argument("-w", "--warmup-runs", type=int, default=5) @@ -527,12 +518,6 @@ def parse_args(): parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display") parser.add_argument("--verbose", default=False, action="store_true") parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files") - parser.add_argument( - "--tune", - default=False, - action="store_true", - help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel", - ) args = parser.parse_args() @@ -546,16 +531,6 @@ def parse_args(): args.execution_provider = f"{args.device.upper()}ExecutionProvider" if args.execution_provider == "CUDAExecutionProvider": args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) - elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = ( - args.execution_provider, - { - "device_id": args.device_id, - "tunable_op_enable": 1, - "tunable_op_tuning_enable": 1 if args.tune else 0, - }, - ) - args.device = "cuda" # Check that model paths have been specified for any benchmarking with ORT if args.benchmark_type == "hf-ort": diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index 95d4b60fead99..a5679fbc2c40e 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -105,7 +105,7 @@ def get_args(): "--device", type=str, required=True, - choices=["cpu", "cuda", "rocm"], + choices=["cpu", "cuda"], help="Device to benchmark models", ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 38fbd73e9c119..79b508047da55 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -26,7 +26,6 @@ PROVIDERS = { "cpu": "CPUExecutionProvider", "cuda": "CUDAExecutionProvider", - "rocm": "ROCMExecutionProvider", } diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 85b3632c516ca..72c51386dfe9e 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -767,7 +767,7 @@ def optimize_onnx( optimization_options = FusionOptions("bart") optimization_options.use_multi_head_attention = True - 
optimization_options.disable_multi_head_attention_bias = provider == "rocm" + optimization_options.disable_multi_head_attention_bias = False m = optimize_model( onnx_model_path, diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 32a25ef1420ba..f4e8bcbe9103e 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -111,7 +111,7 @@ def optimize_by_onnxruntime( use_gpu and provider is None and set(onnxruntime.get_available_providers()).isdisjoint( - ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] + ["CUDAExecutionProvider", "MIGraphXExecutionProvider"] ) ): logger.error("There is no gpu for onnxruntime to do optimization.") @@ -172,10 +172,8 @@ def optimize_by_onnxruntime( elif provider is not None: if provider == "dml": providers = ["DmlExecutionProvider"] - elif provider == "rocm": - providers = ["ROCMExecutionProvider"] elif provider == "migraphx": - providers = ["MIGraphXExecutionProvider", "ROCMExecutionProvider"] + providers = ["MIGraphXExecutionProvider"] elif provider == "cuda": providers = ["CUDAExecutionProvider"] elif provider == "tensorrt": @@ -189,7 +187,6 @@ def optimize_by_onnxruntime( if torch_version.hip: providers.append("MIGraphXExecutionProvider") - providers.append("ROCMExecutionProvider") else: providers.append("CUDAExecutionProvider") diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 12b83c171c565..bce9b59ff0ea4 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -254,16 +254,9 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG continue; // unable to get input shape } - const auto is_static_shape = [](gsl::span shape) -> bool { - return std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim >= 0; }); - }; - - if (!is_static_shape(*input_0_shape) || !is_static_shape(*input_1_shape)) { - continue; // input shape has dynamic dimensions - } - - if (*input_0_shape != *input_1_shape) { - continue; // input shapes do not match (no broadcasting support for now) + // Don't support broadcasting and dynamic dimensions for now. + if (!AreShapesStaticAndEqual(*input_0_shape, *input_1_shape)) { + continue; } } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc new file mode 100644 index 0000000000000..7b939c0685237 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ep_factory.h" +#include "../plugin_ep_utils.h" + +ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger) + : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized + factory_{factory}, + ort_api_{factory.GetOrtApi()}, + ep_api_{factory.GetEpApi()}, + name_{factory.GetEpName()}, + logger_{logger} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
+ + // Initialize the execution provider's function table + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + GetKernelRegistry = GetKernelRegistryImpl; + + // This is not a compiling EP, so don't need the following + Compile = nullptr; + ReleaseNodeComputeInfos = nullptr; + + IGNORE_ORTSTATUS(ort_api_.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + ("ExampleKernelEp has been created with name " + name_).c_str(), + ORT_FILE, __LINE__, __FUNCTION__)); +} + +ExampleKernelEp::~ExampleKernelEp() = default; + +/*static*/ +const char* ORT_API_CALL ExampleKernelEp::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->name_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + try { + ExampleKernelEp* ep = static_cast(this_ptr); + + Ort::ConstGraph graph{ort_graph}; + std::vector all_nodes = graph.GetNodes(); + + if (all_nodes.empty()) { + return nullptr; // No nodes to process + } + + // Collect candidate nodes that this EP may support. + std::vector candidate_nodes; + + for (const auto& node : all_nodes) { + std::string op_type = node.GetOperatorType(); + + if (op_type == "Relu" || op_type == "Squeeze") { + candidate_nodes.push_back(node); + } else if (op_type == "Mul") { + std::vector inputs = node.GetInputs(); + + // Note: ONNX shape inference should ensure Mul has two inputs. + std::optional> input_0_shape = GetTensorShape(inputs[0]); + std::optional> input_1_shape = GetTensorShape(inputs[1]); + + if (!input_0_shape.has_value() || !input_1_shape.has_value()) { + continue; // Unable to get input shapes (non-tensor). + } + + if (!AreShapesStaticAndEqual(*input_0_shape, *input_1_shape)) { + continue; // Don't support broadcasting and dynamic dimensions. + } + + candidate_nodes.push_back(node); + } + } + + // Mark candidate nodes as supported if we have a registered kernel. + for (const auto& node : candidate_nodes) { + const OrtKernelDef* kernel_def = nullptr; + RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def)); + + if (kernel_def != nullptr) { + RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_AddSingleNode(graph_support_info, node)); + } + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEp::GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept { + ExampleKernelEp* ep = static_cast(this_ptr); + + *kernel_registry = nullptr; + + // Get the cached kernel registry from parent factory to avoid recreating the kernel registry for every EP instance. + RETURN_IF_ERROR(ep->factory_.GetKernelRegistryForEp(*ep, kernel_registry)); + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h new file mode 100644 index 0000000000000..35357ddf3f5e2 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +class ExampleKernelEpFactory; + +/// +/// Example EP that uses kernel registration. +/// +class ExampleKernelEp : public OrtEp { + public: + ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger); + ~ExampleKernelEp(); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + ExampleKernelEpFactory& factory_; + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + std::string name_; + const OrtLogger& logger_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc new file mode 100644 index 0000000000000..6017bf9dd9d1e --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_factory.h" + +#include + +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" + +#include "ep.h" +#include "ep_kernel_registration.h" +#include "../plugin_ep_utils.h" + +ExampleKernelEpFactory::ExampleKernelEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtLogger& /*default_logger*/) + : OrtEpFactory{}, + ort_api_(ort_api), + ep_api_(ep_api) { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + + GetSupportedDevices = GetSupportedDevicesImpl; + + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; +} + +ExampleKernelEpFactory::~ExampleKernelEpFactory() { + if (kernel_registry_ != nullptr) { + Ort::GetEpApi().ReleaseKernelRegistry(kernel_registry_); + } +} + +OrtStatus* ExampleKernelEpFactory::GetKernelRegistryForEp(ExampleKernelEp& ep, + const OrtKernelRegistry** out_kernel_registry) { + *out_kernel_registry = nullptr; + + if (GetNumKernels() == 0) { + return nullptr; + } + + if (kernel_registry_ == nullptr) { + void* op_kernel_state = nullptr; // Optional state that is provided to kernels on creation (can be null). + const char* ep_name = ep.GetName(static_cast(&ep)); + + // This statement creates the kernel registry and caches it in the OrtEpFactory instance. + // We assume that all EPs created by this factory can use the same kernel registry. This may not be the + // case in a more complex OrtEpFactory that can create EP instances that are each configured for different + // hardware devices. In such a scenario, a different kernel registry may be created for each EP configuration. 
+ RETURN_IF_ERROR(CreateKernelRegistry(ep_name, op_kernel_state, &kernel_registry_)); + } + + *out_kernel_registry = kernel_registry_; + return nullptr; +} + +/*static*/ +const char* ORT_API_CALL ExampleKernelEpFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_name_.c_str(); +} + +/*static*/ +const char* ORT_API_CALL ExampleKernelEpFactory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_.c_str(); +} + +/*static*/ +uint32_t ORT_API_CALL ExampleKernelEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id_; +} + +/*static*/ +const char* ORT_API_CALL ExampleKernelEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_version_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* hw_devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *hw_devices[i]; + if (factory->ort_api_.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // these can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api_.CreateKeyValuePairs(&ep_metadata); + factory->ort_api_.CreateKeyValuePairs(&ep_options); + + // random example using made up values + factory->ort_api_.AddKeyValuePair(ep_metadata, "supported_devices", "CrackGriffin 7+"); + factory->ort_api_.AddKeyValuePair(ep_options, "run_really_fast", "true"); + + // OrtEpDevice copies ep_metadata and ep_options. + OrtEpDevice* ep_device = nullptr; + auto* status = factory->ort_api_.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_device); + + factory->ort_api_.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api_.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + + // register the allocator info required by the EP. + // registering OrtMemoryInfo for host accessible memory would be done in an additional call. + // OrtReadOnlyAllocator + OrtDeviceMemoryType_DEFAULT allocator for use with initializers is optional. 
+ // RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_.get())); + // RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_.get())); + + ep_devices[num_ep_devices++] = ep_device; + } + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* /*session_options*/, + const OrtLogger* logger, + OrtEp** ep) noexcept { + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + return factory->ort_api_.CreateStatus(ORT_INVALID_ARGUMENT, + "ExampleKernelEpFactory only supports selection for one device."); + } + + auto actual_ep = std::make_unique(*factory, *logger); + *ep = actual_ep.release(); + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleKernelEpFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + delete static_cast(ep); +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateAllocatorImpl(OrtEpFactory* /*this_ptr*/, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + // Don't support custom allocators in this example for simplicity. A GPU EP would normally support allocators. + *allocator = nullptr; + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleKernelEpFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, + OrtAllocator* /*allocator*/) noexcept { + // Do nothing. +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + // Don't support data transfer in this example for simplicity. A GPU EP would normally support it. + *data_transfer = nullptr; + return nullptr; +} + +/*static*/ +bool ORT_API_CALL ExampleKernelEpFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFactory* /*this_ptr*/, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_opts*/, + OrtSyncStreamImpl** stream) noexcept { + // Don't support sync streams in this example. + *stream = nullptr; + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h new file mode 100644 index 0000000000000..9ddbeee585115 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +class ExampleKernelEp; + +/// +/// EP factory that creates an OrtEp instance that uses kernel registration. 
+/// +class ExampleKernelEpFactory : public OrtEpFactory { + public: + ExampleKernelEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, const OrtLogger& default_logger); + ~ExampleKernelEpFactory(); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + const std::string& GetEpName() const { return ep_name_; } + + // Called by child OrtEp instances to retrieve the cached kernel registry for that EP. + OrtStatus* GetKernelRegistryForEp(ExampleKernelEp& ep, /*out*/ const OrtKernelRegistry** kernel_registry); + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const std::string ep_name_{"ExampleKernelEp"}; + const std::string vendor_{"Contoso2"}; // EP vendor name + const uint32_t vendor_id_{0xB358}; // EP vendor ID + const std::string ep_version_{"0.1.0"}; // EP version + + // Cached kernel registry used by all OrtEp instances created by this factory. Refer to OrtEp::GetKernelRegistry. + // + // Note: If this factory instead created EP instances that each supported different hardware configurations, then + // the factory could cache a different kernel registry per EP configuration. + OrtKernelRegistry* kernel_registry_ = nullptr; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc new file mode 100644 index 0000000000000..b9518786f3a04 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "ep_kernel_registration.h" +#include "kernels/utils.h" + +// Table of BuildKernelCreateInfo functions for each operator +static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = { + // Mul version 14 + BuildKernelCreateInfo, + + // Relu version 14 + BuildKernelCreateInfo, + + // Support Squeeze 21, 23, and 24. + // Note: end versions are inclusive. + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +}; + +size_t GetNumKernels() { + return std::size(build_kernel_create_info_funcs); +} + +static OrtStatus* RegisterKernels(Ort::KernelRegistry& kernel_registry, const char* ep_name, + void* create_kernel_state) { + for (auto& build_func : build_kernel_create_info_funcs) { + KernelCreateInfo kernel_create_info = {}; + RETURN_IF_ERROR(build_func(ep_name, create_kernel_state, &kernel_create_info)); + + if (kernel_create_info.kernel_def != nullptr) { + RETURN_IF_ERROR(kernel_registry.AddKernel(kernel_create_info.kernel_def, + kernel_create_info.kernel_create_func, + kernel_create_info.kernel_create_func_state)); + } + } + + return nullptr; +} + +OrtStatus* CreateKernelRegistry(const char* ep_name, void* create_kernel_state, + OrtKernelRegistry** out_kernel_registry) { + *out_kernel_registry = nullptr; + + if (GetNumKernels() == 0) { + return nullptr; + } + + try { + Ort::KernelRegistry kernel_registry; + Ort::Status status{RegisterKernels(kernel_registry, ep_name, create_kernel_state)}; + + *out_kernel_registry = status.IsOK() ? kernel_registry.release() : nullptr; + return status.release(); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h new file mode 100644 index 0000000000000..d88cdb35afa6c --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +#include "../plugin_ep_utils.h" + +size_t GetNumKernels(); + +OrtStatus* CreateKernelRegistry(const char* ep_name, void* create_kernel_state, OrtKernelRegistry** kernel_registry); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib.def b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib.def new file mode 100644 index 0000000000000..f2924fb2b1f43 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib.def @@ -0,0 +1,5 @@ +LIBRARY "example_plugin_ep_kernel_registry.dll" +EXPORTS + CreateEpFactories @1 + ReleaseEpFactory @2 + diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib.lds b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib.lds new file mode 100644 index 0000000000000..a6d2ef09a7b16 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib.lds @@ -0,0 +1,7 @@ +VERS_1.0.0 { + global: + CreateEpFactories; + ReleaseEpFactory; + local: + *; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc new file mode 100644 index 0000000000000..74f473b9b0320 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include "ep_factory.h" + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + + // Manual init for the C++ API + Ort::InitApi(ort_api); + + std::unique_ptr factory = std::make_unique(*ort_api, *ep_api, *default_logger); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} + +} // extern "C" diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc new file mode 100644 index 0000000000000..30f83e1771dd7 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "base.h" + +BaseKernelImpl::BaseKernelImpl(const OrtKernelInfo* info, void* state) : info_{info}, state_{state} { + ort_version_supported = ORT_API_VERSION; + Compute = ComputeImpl; + Release = ReleaseImpl; +} + +/*static*/ +OrtStatus* ORT_API_CALL BaseKernelImpl::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { + try { + BaseKernelImpl* base_kernel = static_cast(this_ptr); + return base_kernel->DoCompute(kernel_ctx); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } +} + +/*static*/ +void ORT_API_CALL BaseKernelImpl::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h new file mode 100644 index 0000000000000..c4afe1b2e0670 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../../plugin_ep_utils.h" + +// Base class for kernel implementations. +// +// Note: BaseKernelImpl has virtual functions so care should be taken when casting BaseKernelImpl to a OrtKernelImpl, +// which is a C API struct type. Specifically, a static_cast or implicit cast should be used. A reinterpret_cast +// will result in an invalid object due to the presence of the vtable. +class BaseKernelImpl : public OrtKernelImpl { + public: + BaseKernelImpl(const OrtKernelInfo* info, void* state); + virtual ~BaseKernelImpl() = default; + + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + + private: + // Derived classes implement DoCompute. + // DoCompute is called by BaseKernelImpl::ComputeImpl, which also catches exceptions thrown by DoCompute + // implementations and converts them into OrtStatus*. + virtual OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) = 0; + + protected: + const OrtKernelInfo* info_; + void* state_; // Custom state passed from OrtEp +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc new file mode 100644 index 0000000000000..979dc5e9c1303 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "mul.h" +#include "utils.h" + +// Defines a kernel creation function for version 14 of Mul. +ONNX_OPERATOR_KERNEL_EX( + Mul, + kOnnxDomain, + /*version*/ 14, // Equivalent to start_version: 14, end_version: 14 (inclusive) + (Ort::KernelDefBuilder() + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + Mul) + +Mul::Mul(const OrtKernelInfo* info, void* state, PrivateTag) : BaseKernelImpl(info, state) {} + +/*static*/ +OrtStatus* Mul::Create(const OrtKernelInfo* info, void* state, + /*out*/ std::unique_ptr& result) { + // Note: can do basic validation or preprocessing via the OrtKernelInfo APIs. 
+ result = std::make_unique(info, state, PrivateTag{}); + return nullptr; +} + +OrtStatus* Mul::DoCompute(OrtKernelContext* kernel_ctx) { + Ort::KernelContext kernel_context(kernel_ctx); + static_cast(this->state_); // NOTE: Unused in this example. + static_cast(this->info_); // NOTE: Unused in this example. + + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0)); + RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 1, input1, shape1)); + RETURN_IF(shape0 != shape1, Ort::GetApi(), "Mul kernel doesn't support broadcasting."); // Checked by GetCapability + + Ort::UnownedValue output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); + + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h new file mode 100644 index 0000000000000..882a19a13e23e --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "base.h" +#include "../../plugin_ep_utils.h" + +class Mul : public BaseKernelImpl { + private: + struct PrivateTag {}; + + public: + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel); + Mul(const OrtKernelInfo* info, void* state, PrivateTag); + + private: + OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) override; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc new file mode 100644 index 0000000000000..82444b815a1ee --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "relu.h" + +#include +#include +#include + +#include "utils.h" + +// Defines a kernel creation function for version 14 of Relu. +ONNX_OPERATOR_KERNEL_EX( + Relu, + kOnnxDomain, + /*version*/ 14, // Equivalent to start_version: 14, end_version: 14 (inclusive) + (Ort::KernelDefBuilder() + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) + .AddInputOutputMutableAlias(0, 0)), + Relu) + +Relu::Relu(const OrtKernelInfo* info, void* state, PrivateTag) : BaseKernelImpl(info, state) {} + +/*static*/ +OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) { + Ort::ConstKernelInfo kernel_info(info); + kernel = std::make_unique(info, state, PrivateTag{}); + return nullptr; +} + +OrtStatus* Relu::DoCompute(OrtKernelContext* kernel_ctx) { + Ort::KernelContext kernel_context(kernel_ctx); + static_cast(this->state_); // NOTE: Unused in this example. + static_cast(this->info_); // NOTE: Unused in this example. 
+ + gsl::span input0; + std::vector shape0; + RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0)); + + Ort::UnownedValue output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); + + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = std::max(0.0f, input0[i]); + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h new file mode 100644 index 0000000000000..4f5ba8bc0e77b --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "base.h" +#include "../../plugin_ep_utils.h" + +class Relu : public BaseKernelImpl { + private: + struct PrivateTag {}; + + public: + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel); + Relu(const OrtKernelInfo* info, void* state, PrivateTag); + + private: + OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) override; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc new file mode 100644 index 0000000000000..5311911a8c413 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "squeeze.h" + +#include +#include + +#include "utils.h" + +// Support ONNX Squeeze versions 21, 23, and 24. +// Kernel creation functions are typically defined separately for new operator versions to account for things like new +// data types. One could technically support all three versions with a single call to +// ONNX_OPERATOR_VERSIONED_KERNEL_EX(Squeeze, kOnnxDomain, 21, 24, ...), but this example shows the more common usage. + +// ONNX Squeeze version 21 +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + /*start_version*/ 21, /*end_version (inclusive)*/ 22, + (Ort::KernelDefBuilder() + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) + .AddTypeConstraint("axes", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddInputOutputAlias(0, 0)), + Squeeze) + +// ONNX Squeeze version 23 +ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + /*version*/ 23, // Equivalent to start_version: 23, end_version: 23 + (Ort::KernelDefBuilder() + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) + .AddTypeConstraint("axes", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddInputOutputAlias(0, 0)), + Squeeze) + +// ONNX Squeeze version 24. 
+ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + /*version*/ 24, // Equivalent start_version: 24, end_version: 24 + (Ort::KernelDefBuilder() + .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) + .AddTypeConstraint("axes", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddInputOutputAlias(0, 0)), + Squeeze) + +Squeeze::Squeeze(const OrtKernelInfo* info, void* state, PrivateTag) : BaseKernelImpl(info, state) {} + +/*static*/ +OrtStatus* Squeeze::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) { + Ort::ConstKernelInfo kernel_info(info); + kernel = std::make_unique(info, state, PrivateTag{}); + return nullptr; +} + +static int64_t HandleNegativeAxis(int64_t axis, int64_t tensor_rank) { + return axis < 0 ? axis + tensor_rank : axis; +} + +static std::vector ComputeOutputShape(gsl::span input_shape, gsl::span axes) { + size_t j = 0; + std::vector output_shape; + auto num_dimensions = input_shape.size(); + + // Handle negative axis, then resort and uniq. + std::vector axes_corrected(axes.size()); + for (size_t i = 0; i < axes.size(); i++) { + axes_corrected[i] = HandleNegativeAxis(axes[i], num_dimensions); + } + std::sort(axes_corrected.begin(), axes_corrected.end()); + axes_corrected.erase(std::unique(axes_corrected.begin(), axes_corrected.end()), axes_corrected.end()); + + for (size_t i = 0; i < num_dimensions; ++i) { + if ((j < axes_corrected.size() && axes_corrected[j] == static_cast(i)) || + (axes_corrected.size() == 0 && input_shape[i] == 1)) { + assert(input_shape[i] == 1); + ++j; + continue; + } + output_shape.push_back(input_shape[i]); + } + return output_shape; +} + +OrtStatus* Squeeze::DoCompute(OrtKernelContext* kernel_ctx) { + Ort::KernelContext kernel_context(kernel_ctx); + static_cast(this->state_); // NOTE: Unused in this example. + + gsl::span input0; + std::vector shape0; + RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0)); + + size_t num_inputs = kernel_context.GetInputCount(); + std::vector axes; + + if (num_inputs == 2) { + // Axes is an explicit input. + gsl::span axes_input; + std::vector axes_shape; + RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 1, axes_input, axes_shape)); + assert(axes_shape.size() == 1); + + axes.assign(axes_input.begin(), axes_input.end()); + } + + std::vector output_shape = ComputeOutputShape(shape0, axes); + Ort::UnownedValue output = kernel_context.GetOutput(0, output_shape); + float* output_data = output.GetTensorMutableData(); + size_t num_bytes = output.GetTensorSizeInBytes(); + + if (input0.data() != output_data) { // Don't copy if src == dst + // This uses a memcpy because the input and output are both located in the EP's device memory (i.e., cpu memory). + // Normally, an EP would use a OrtDataTransferImpl to generically handle copies where the source and destination + // could be on different devices. + memcpy(output_data, input0.data(), num_bytes); + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h new file mode 100644 index 0000000000000..9faf91c1d2b3c --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "base.h" +#include "../../plugin_ep_utils.h" + +class Squeeze : public BaseKernelImpl { + private: + struct PrivateTag {}; + + public: + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel); + Squeeze(const OrtKernelInfo* info, void* state, PrivateTag); + + private: + OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) override; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h new file mode 100644 index 0000000000000..615ee3911108a --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../../plugin_ep_utils.h" + +/// +/// Gets an OrtDataType for a tensor type. Throws on error. +/// +/// +/// +inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) { + const OrtEpApi& ep_api = Ort::GetEpApi(); + const OrtDataType* result = nullptr; + + Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result)); + return result; +} + +/// +/// Contains information to create a kernel: kernel definition, creation function + state. +/// +struct KernelCreateInfo { + KernelCreateInfo() = default; + KernelCreateInfo(Ort::KernelDef def, OrtKernelCreateFunc func, void* state) + : kernel_def{std::move(def)}, kernel_create_func{func}, kernel_create_func_state{state} {} + + Ort::KernelDef kernel_def{nullptr}; + OrtKernelCreateFunc kernel_create_func = nullptr; + void* kernel_create_func_state = nullptr; +}; + +using BuildKernelCreateInfoFn = OrtStatus* (*)(const char*, void*, KernelCreateInfo*); + +template +OrtStatus* BuildKernelCreateInfo(const char* ep_name, void* create_func_state, /*out*/ KernelCreateInfo* result); + +template <> +inline OrtStatus* BuildKernelCreateInfo(const char* /*ep_name*/, void* /*create_func_state*/, + /*out*/ KernelCreateInfo* result) { + result->kernel_def = Ort::KernelDef{nullptr}; + result->kernel_create_func = nullptr; + result->kernel_create_func_state = nullptr; + return nullptr; +} + +static constexpr const char* kOnnxDomain = ""; + +// Naming convention for operator kernel classes with a start and end version range. +#define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name) \ + example_ep_##name##_##domain##_ver##startver##_##endver + +// Naming convention for operator kernel classes for a single version +#define ONNX_OPERATOR_KERNEL_CLASS_NAME(domain, version, name) \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, version, version, name) + +// Defines a function of type BuildKernelCreateInfoFn for a kernel implementation with a start and end version range. 
+#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, builder, kernel_class)    \
+  class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name);                  \
+  template <>                                                                                        \
+  OrtStatus*                                                                                         \
+  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name)>( \
+      const char* ep_name,                                                                           \
+      void* create_kernel_state,                                                                     \
+      KernelCreateInfo* result) {                                                                    \
+    try {                                                                                            \
+      Ort::KernelDef kernel_def = builder.SetOperatorType(#name)                                     \
+                                      .SetDomain(domain)                                             \
+                                      .SetSinceVersion(startver, endver)                             \
+                                      .SetExecutionProvider(ep_name)                                 \
+                                      .Build();                                                      \
+                                                                                                     \
+      auto kernel_create_func = [](void* state, const OrtKernelInfo* info,                           \
+                                   OrtKernelImpl** kernel_out) noexcept -> OrtStatus* {              \
+        *kernel_out = nullptr;                                                                       \
+                                                                                                     \
+        std::unique_ptr<kernel_class> kernel;                                                        \
+        RETURN_IF_ERROR(kernel_class::Create(info, state, kernel));                                  \
+        *kernel_out = kernel.release();                                                              \
+        return nullptr;                                                                              \
+      };                                                                                             \
+                                                                                                     \
+      *result = KernelCreateInfo(std::move(kernel_def), kernel_create_func, create_kernel_state);    \
+    } catch (const Ort::Exception& ex) {                                                             \
+      Ort::Status status(ex);                                                                        \
+      return status.release();                                                                       \
+    } catch (const std::exception& ex) {                                                             \
+      Ort::Status status(ex.what(), ORT_EP_FAIL);                                                    \
+      return status.release();                                                                       \
+    }                                                                                                \
+    return nullptr;                                                                                  \
+  }
+
+// Defines a function of type BuildKernelCreateInfoFn for a kernel implementation with a start version.
+#define ONNX_OPERATOR_KERNEL_EX(name, domain, version, builder, kernel_class) \
+  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, version, version, builder, kernel_class)
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc
index 385238961eb66..2fd9fe542b0ce 100644
--- a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc
@@ -3,7 +3,7 @@
 
 #include "ep.h"
 
-#include
+#include
 #include
 #include
 #include
diff --git a/onnxruntime/test/autoep/library/plugin_ep_utils.h b/onnxruntime/test/autoep/library/plugin_ep_utils.h
index 3636e69be2d44..d14186425458f 100644
--- a/onnxruntime/test/autoep/library/plugin_ep_utils.h
+++ b/onnxruntime/test/autoep/library/plugin_ep_utils.h
@@ -3,6 +3,8 @@
 
 #pragma once
 
+#include
+#include
 #include
 #include
 #include
@@ -12,12 +14,12 @@
 #include "onnxruntime_cxx_api.h"
 #undef ORT_API_MANUAL_INIT
 
-#define RETURN_IF_ERROR(fn)      \
-  do {                           \
-    OrtStatus* _status = (fn);   \
-    if (_status != nullptr) {    \
-      return _status;            \
-    }                            \
+#define RETURN_IF_ERROR(fn)       \
+  do {                            \
+    Ort::Status _status{(fn)};    \
+    if (!_status.IsOK()) {        \
+      return _status.release();   \
+    }                             \
   } while (0)
 
 #define RETURN_IF(cond, ort_api, msg) \
@@ -129,3 +131,50 @@ inline std::optional<std::vector<int64_t>> GetTensorShape(Ort::ConstValueInfo value_info) {
   const auto type_shape = type_info.GetTensorTypeAndShapeInfo();
   return type_shape.GetShape();
 }
+
+// Check if two shapes are static (no dynamic dimensions) and equal.
+inline bool AreShapesStaticAndEqual(gsl::span<const int64_t> shape0, gsl::span<const int64_t> shape1) {
+  const auto is_static_shape = [](gsl::span<const int64_t> shape) -> bool {
+    return std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim >= 0; });
+  };
+
+  if (!is_static_shape(shape0) || !is_static_shape(shape1)) {
+    return false;  // a shape has dynamic dimensions
+  }
+
+  return shape0 == shape1;
+}
+
+template <typename T>
+inline ONNXTensorElementDataType GetTensorElemDataType() {
+  return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+}
+
+template <>
+inline ONNXTensorElementDataType GetTensorElemDataType<float>() {
+  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+}
+
+template <>
+inline ONNXTensorElementDataType GetTensorElemDataType<int64_t>() {
+  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+}
+
+template <typename T>
+inline OrtStatus* GetKernelInputDataAndShape(Ort::KernelContext kernel_context, size_t index,
+                                             /*out*/ gsl::span<const T>& data,
+                                             /*out*/ std::vector<int64_t>& shape) {
+  Ort::ConstValue input = kernel_context.GetInput(index);
+  auto type_shape = input.GetTensorTypeAndShapeInfo();
+
+  ONNXTensorElementDataType elem_type = type_shape.GetElementType();
+  RETURN_IF(elem_type != GetTensorElemDataType<T>(), Ort::GetApi(),
+            "EP expected kernel input of tensor type");
+
+  const T* float_data = input.GetTensorData<T>();
+  size_t num_elems = type_shape.GetElementCount();
+  data = gsl::span<const T>(float_data, num_elems);
+  shape = type_shape.GetShape();
+
+  return nullptr;
+}
diff --git a/onnxruntime/test/autoep/library/readme.md b/onnxruntime/test/autoep/library/readme.md
index 4750c2c792b74..04aaa5ca88973 100644
--- a/onnxruntime/test/autoep/library/readme.md
+++ b/onnxruntime/test/autoep/library/readme.md
@@ -18,6 +18,10 @@ used for testing and as reference examples.
   Contains a compiling plugin execution provider that registers its own virtual hardware device. Virtual devices can
   be used for cross compiling models for different targets.
 
+- `example_plugin_ep_kernel_registry/`
+  Contains a basic plugin execution provider that registers operator kernels with ONNX Runtime, as opposed to compiling
+  nodes.
+
 - `plugin_ep_utils.h`
   Common utilities for the example plugin execution provider implementations.
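For orientation, the sketch below shows one way a plugin EP could collect the KernelCreateInfo entries produced by the BuildKernelCreateInfo<...> specializations that ONNX_OPERATOR_KERNEL_EX generates (as in kernels/squeeze.cc and kernels/utils.h above). This is an illustrative sketch only, not code from this change: the helper name GetExampleKernelCreateInfos and the exact registration flow used by the example library are assumptions.

    // Hypothetical helper (illustration only): gather kernel create infos from the
    // BuildKernelCreateInfo<...> specializations generated by ONNX_OPERATOR_KERNEL_EX.
    static OrtStatus* GetExampleKernelCreateInfos(const char* ep_name, void* kernel_state,
                                                  /*out*/ std::vector<KernelCreateInfo>& kernel_infos) {
      static constexpr BuildKernelCreateInfoFn builder_fns[] = {
          BuildKernelCreateInfo<void>,  // default placeholder entry; produces no kernel
          BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 24, Squeeze)>,
      };

      for (BuildKernelCreateInfoFn builder_fn : builder_fns) {
        KernelCreateInfo info;
        RETURN_IF_ERROR(builder_fn(ep_name, kernel_state, &info));
        if (info.kernel_create_func != nullptr) {  // skip the BuildKernelCreateInfo<void> placeholder
          kernel_infos.push_back(std::move(info));
        }
      }

      return nullptr;
    }

Each surviving entry pairs an Ort::KernelDef with the creation callback that instantiates the corresponding kernel class (e.g. Squeeze) when ONNX Runtime assigns a matching node to the EP.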
diff --git a/onnxruntime/test/autoep/test_autoep_utils.cc b/onnxruntime/test/autoep/test_autoep_utils.cc index 07e7c12c4ad99..0de64690c0b3e 100644 --- a/onnxruntime/test/autoep/test_autoep_utils.cc +++ b/onnxruntime/test/autoep/test_autoep_utils.cc @@ -28,6 +28,11 @@ const Utils::ExamplePluginInfo Utils::example_ep_virt_gpu_info( // This EP's name is hardcoded to the following "EpVirtualGpu"); +const Utils::ExamplePluginInfo Utils::example_ep_kernel_registry_info( + GetSharedLibraryFileName(ORT_TSTR("example_plugin_ep_kernel_registry")), + "example_plugin_ep_kernel_registry", + "ExampleKernelEp"); + void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device) { const OrtApi& c_api = Ort::GetApi(); const OrtEpDevice* const* ep_devices = nullptr; diff --git a/onnxruntime/test/autoep/test_autoep_utils.h b/onnxruntime/test/autoep/test_autoep_utils.h index ae930644779a0..c1e118a001cc6 100644 --- a/onnxruntime/test/autoep/test_autoep_utils.h +++ b/onnxruntime/test/autoep/test_autoep_utils.h @@ -22,8 +22,9 @@ struct Utils { std::string ep_name; }; - static const ExamplePluginInfo example_ep_info; // example_plugin_ep.dll - static const ExamplePluginInfo example_ep_virt_gpu_info; // example_plugin_ep_virt_gpu.dll + static const ExamplePluginInfo example_ep_info; // example_plugin_ep.dll + static const ExamplePluginInfo example_ep_virt_gpu_info; // example_plugin_ep_virt_gpu.dll + static const ExamplePluginInfo example_ep_kernel_registry_info; // example_plugin_ep_kernel_registry.dll // get the OrtEpDevice for an arbitrary EP from the environment static void GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device); diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 4c88d3ec2e0f3..bb391bb0bca23 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -84,7 +84,7 @@ void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_o // Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep); + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); // Create session with example plugin EP @@ -99,7 +99,7 @@ TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { // Uses the PREFER_CPU policy to append the example plugin EP to the session. TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep); + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); { // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. 
@@ -111,7 +111,7 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) { RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep); + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); // Create session with example plugin EP @@ -126,7 +126,7 @@ TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) { // This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep); + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); { @@ -156,7 +156,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { // Generate an EPContext model with a plugin EP that uses a virtual GPU. TEST(OrtEpLibrary, PluginEp_VirtGpu_GenEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_virt_gpu_info, example_ep); + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_virt_gpu_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); { @@ -192,7 +192,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel_ErrorOutputModelExists_AutoGenOutp std::filesystem::remove(expected_output_model_file); RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep); + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); std::unordered_map ep_options; @@ -238,5 +238,49 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel_ErrorOutputModelExists_AutoGenOutp std::filesystem::remove(expected_output_model_file); } + +TEST(OrtEpLibrary, KernelPluginEp_Inference) { + RegisteredEpDeviceUniquePtr example_kernel_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_kernel_registry_info, + example_kernel_ep)); + Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get()); + + // Create session with example kernel-based plugin EP + Ort::SessionOptions session_options; + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP. + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + // This model has Squeeze, Mul, and Relu nodes. The example plugin EP supports all nodes using registered kernels. 
+  Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options);
+
+  // Create inputs
+  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+  std::array<int64_t, 3> a_shape = {3, 1, 2};
+  std::array<int64_t, 2> b_shape = {3, 2};
+
+  std::array<float, 6> a_data = {1.f, -2.f, 3.f, 4.f, -5.f, 6.f};
+  std::array<float, 6> b_data = {2.f, 3.f, 4.f, -5.f, 6.f, 7.f};
+
+  std::vector<Ort::Value> ort_inputs{};
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size()));
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(memory_info, b_data.data(), b_data.size(), b_shape.data(), b_shape.size()));
+
+  std::array<const char*, 2> ort_input_names{"A", "B"};
+
+  // Run session and get outputs
+  std::array<const char*, 1> output_names{"C"};
+  std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(),
+                                                    ort_inputs.size(), output_names.data(), output_names.size());
+
+  // Check expected output values
+  Ort::Value& ort_output = ort_outputs[0];
+  const float* output_data = ort_output.GetTensorData<float>();
+  gsl::span<const float> output_span(output_data, 6);
+  EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84));
+}
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index 99fd3c18e94ef..411629535254d 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -57,7 +57,6 @@ static void RunAttentionTest(
     int max_sequence_length = 0,
     const bool disable_cpu = false,
     const bool disable_cuda = false,
-    const bool disable_rocm = false,
     const bool disable_dml = false,
     const bool disable_webgpu = false,
     std::vector qkv_sizes = {},
@@ -72,22 +71,21 @@ static void RunAttentionTest(
   int min_cuda_architecture = use_float16 ?
1 : 0)); } @@ -241,18 +239,6 @@ static void RunAttentionTest( tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (enable_rocm) { - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - - if (enable_rocm) { - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true)); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (enable_cpu) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -297,7 +283,6 @@ static void RunAttentionTest( int max_sequence_length = 0, const bool disable_cpu = false, const bool disable_cuda = false, - const bool disable_rocm = false, const bool disable_dml = false, const bool disable_webgpu = false, const std::vector qkv_sizes = {}, @@ -310,13 +295,13 @@ static void RunAttentionTest( batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, disable_dml, disable_webgpu, qkv_sizes, attention_bias_data, + disable_cpu, disable_cuda, disable_dml, disable_webgpu, qkv_sizes, attention_bias_data, kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, disable_dml, disable_webgpu, qkv_sizes, attention_bias_data, + disable_cpu, disable_cuda, disable_dml, disable_webgpu, qkv_sizes, attention_bias_data, kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); } @@ -383,11 +368,10 @@ TEST(ContribOpAttentionTest, AttentionBatch1WithQKVAttr1) { 3.1967618465423584f, 0.51903456449508667f, 0.63051539659500122f, 2.9394614696502686f, 0.65332180261611938f, 1.000949501991272f, 0.74175024032592773f, 2.8231701850891113f}; - constexpr bool disable_rocm = true; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, false, false, disable_rocm, false, false, qkv_sizes); + 0, false, false, false, false, qkv_sizes); } TEST(ContribOpAttentionTest, AttentionBatch1WithQKVAttr2) { @@ -421,11 +405,10 @@ TEST(ContribOpAttentionTest, AttentionBatch1WithQKVAttr2) { std::vector output_data = { 0.64932525157928467f, 0.79390722513198853f, 0.64932847023010254f, 0.79375863075256348f}; - constexpr bool disable_rocm = true; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, false, false, disable_rocm, false, false, qkv_sizes); + 0, false, false, false, false, qkv_sizes); } TEST(ContribOpAttentionTest, AttentionBatch1AttentionBias) { @@ -461,12 +444,11 @@ TEST(ContribOpAttentionTest, AttentionBatch1AttentionBias) { constexpr bool disable_cpu = false; constexpr 
bool disable_cuda = false; - constexpr bool disable_rocm = false; constexpr bool disable_dml = false; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, false, qkv_sizes, attention_bias); + 0, disable_cpu, disable_cuda, disable_dml, false, qkv_sizes, attention_bias); } TEST(ContribOpAttentionTest, AttentionBatch2AttentionBias) { @@ -507,12 +489,11 @@ TEST(ContribOpAttentionTest, AttentionBatch2AttentionBias) { constexpr bool disable_cpu = false; constexpr bool disable_cuda = false; - constexpr bool disable_rocm = false; constexpr bool disable_dml = false; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, false, qkv_sizes, attention_bias); + 0, disable_cpu, disable_cuda, disable_dml, false, qkv_sizes, attention_bias); } TEST(ContribOpAttentionTest, AttentionBatch1_Float16) { @@ -859,7 +840,7 @@ void RawAttentionEmptyPastState(bool past_present_share_buffer) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data, - AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, sequence_length, true, false, true, disable_dml, true, + AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, sequence_length, true, false, disable_dml, true, {}, {}, 0, true); } } @@ -1042,7 +1023,7 @@ void RawAttentionPastStateBatch1(bool past_present_share_buffer) { batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, past_sequence_length + sequence_length + 4, - true, false, true, disable_dml, true, {}, {}, 0, true); + true, false, disable_dml, true, {}, {}, 0, true); } } @@ -1175,7 +1156,7 @@ void RawAttentionPastStateBatch2(bool past_present_share_buffer) { batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, past_sequence_length + sequence_length, - true, false, true, disable_dml, true, {}, {}, 0, true); + true, false, disable_dml, true, {}, {}, 0, true); } } @@ -1300,7 +1281,7 @@ void RawAttentionPastStateBatch2WithPadding(bool past_present_share_buffer) { use_past_state, past_sequence_length, &past_data, &present_data, AttentionMaskType::MASK_1D_END_START, 0, past_sequence_length + sequence_length + 4, - true, false, true, disable_dml, true, {}, {}, 0, true); + true, false, disable_dml, true, {}, {}, 0, true); } } @@ -1687,7 +1668,7 @@ TEST(ContribOpAttentionTest, AttentionWithNormFactor) { batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, - false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, false /*disable_dml*/, + false /*disable_cpu*/, false /*disable_cuda*/, false /*disable_dml*/, false /*disable_webgpu*/, {} 
/*qkv_sizes*/, {} /*attention_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/); } @@ -1721,7 +1702,7 @@ TEST(ContribOpAttentionTest, AttentionWithNeoXRotaryEmbedding) { batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, - true /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, disable_dml, + true /*disable_cpu*/, false /*disable_cuda*/, disable_dml, true /*disable_webgpu*/, {} /*qkv_sizes*/, {} /*attention_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/, true /*use_neox_rotary_embedding*/); } @@ -1983,7 +1964,7 @@ TEST(ContribOpAttentionTest, Attention4DMask) { batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_4D_MEGATRON, input_hidden_size, max_sequence_length, - disable_cpu, /* disable_cuda */ false, /* disable_rocm */ false, /* disable_dml */ false, /* disable_webgpu */ true); + disable_cpu, /* disable_cuda */ false, /* disable_dml */ false, /* disable_webgpu */ true); } TEST(ContribOpAttentionTest, AttentionMaskIndexOutOfRange) { @@ -2137,10 +2118,9 @@ static void RunModelWithRandomInput( float gpu_threshold = is_float16 ? 0.5f : 0.005f; constexpr float cpu_threshold = 0.002f; bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get() && !is_float16); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); - if (enable_cuda || enable_rocm || enable_dml) { + if (enable_cuda || enable_dml) { OpTester test("Attention", 1, onnxruntime::kMSDomain); test.AddAttribute("num_heads", num_heads); if (is_float16) { @@ -2162,8 +2142,6 @@ static void RunModelWithRandomInput( execution_providers.push_back(DefaultCudaExecutionProvider()); } else if (enable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); - } else { - execution_providers.push_back(DefaultRocmExecutionProvider()); } test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 20eea2138340f..f875c710046ac 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -83,11 +83,6 @@ void RunGptBeamSearchFp32() { session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif -#ifdef USE_ROCM - OrtROCMProviderOptions rocm_options; - session_options.AppendExecutionProvider_ROCM(rocm_options); -#endif - // The ONNX model is generated like the following: // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 // --output tiny_gpt2_beamsearch_fp16.onnx --use_gpu --max_length 20 @@ -177,8 +172,7 @@ TEST(BeamSearchTest, GptBeamSearchFp16) { constexpr int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); - if (enable_cuda || enable_rocm) { + if (enable_cuda) { Ort::SessionOptions session_options; #ifdef USE_CUDA OrtCUDAProviderOptionsV2 cuda_options; @@ -186,11 +180,6 @@ TEST(BeamSearchTest, 
GptBeamSearchFp16) { session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif -#ifdef USE_ROCM - OrtROCMProviderOptions rocm_options; - session_options.AppendExecutionProvider_ROCM(rocm_options); -#endif - // The ONNX model is generated like the following: // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 // --output tiny_gpt2_beamsearch_fp16.onnx -p fp16 --use_gpu --max_length 20 @@ -272,8 +261,7 @@ TEST(BeamSearchTest, GptBeamSearchWithInitDecoderFp16) { constexpr int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); - if (enable_cuda || enable_rocm) { + if (enable_cuda) { Ort::SessionOptions session_options; #ifdef USE_CUDA OrtCUDAProviderOptionsV2 cuda_options; @@ -281,11 +269,6 @@ TEST(BeamSearchTest, GptBeamSearchWithInitDecoderFp16) { session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif -#ifdef USE_ROCM - OrtROCMProviderOptions rocm_options; - session_options.AppendExecutionProvider_ROCM(rocm_options); -#endif - // The ONNX model is generated like the following: // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 // --output tiny_gpt2_beamsearch_with_init_decoder_fp16.onnx -p fp16 --use_gpu --max_length 20 @@ -366,8 +349,7 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { constexpr int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); - if (enable_cuda || enable_rocm) { + if (enable_cuda) { Ort::SessionOptions session_options; #ifdef USE_CUDA OrtCUDAProviderOptionsV2 cuda_options; @@ -375,11 +357,6 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif -#ifdef USE_ROCM - OrtROCMProviderOptions rocm_options; - session_options.AppendExecutionProvider_ROCM(rocm_options); -#endif - // The following model was obtained by padding the vocabulary size in testdata/transformers/tiny_gpt2_beamsearch_fp16.onnx // from 1000 to 1600 (just for illustrative and testing purposes) to see if the beam search implementation can handle // such a scenario diff --git a/onnxruntime/test/contrib_ops/bias_add_op_test.cc b/onnxruntime/test/contrib_ops/bias_add_op_test.cc index 6fd091ef66110..1ec51631ca9ca 100644 --- a/onnxruntime/test/contrib_ops/bias_add_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_add_op_test.cc @@ -13,7 +13,7 @@ using namespace onnxruntime::test; namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_DML) static std::vector GetExpectedResult(const std::vector& input_data, const std::vector& bias_data, const std::vector& skip_data) { @@ -38,10 +38,9 @@ static void RunSkipBiasGpuTest(const std::vector& input_data, bool use_float16 = false) { int min_cuda_architecture = use_float16 ? 
530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); - if (!enable_cuda && !enable_rocm && !enable_dml) { + if (!enable_cuda && !enable_dml) { return; } @@ -63,9 +62,7 @@ static void RunSkipBiasGpuTest(const std::vector& input_data, if (enable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } + if (enable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } diff --git a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc index 027d4b3fff1b0..4852269a5b6b6 100644 --- a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// BiasDropout kernel is only implemented for CUDA/ROCM -#if (defined(USE_CUDA) && !defined(USE_CUDA_MINIMAL)) || defined(USE_ROCM) +// BiasDropout kernel is only implemented for CUDA +#if (defined(USE_CUDA) && !defined(USE_CUDA_MINIMAL)) #ifdef _MSC_VER #pragma warning(disable : 4389) @@ -17,23 +17,14 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" -#ifdef USE_ROCM -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#else #include "core/providers/cuda/shared_inc/cuda_utils.h" -#endif namespace onnxruntime { namespace contrib { namespace test { -#ifdef USE_ROCM -using onnxruntime::rocm::BitmaskElementType; -using onnxruntime::rocm::kNumBitsPerBitmaskElement; -#else using onnxruntime::cuda::BitmaskElementType; using onnxruntime::cuda::kNumBitsPerBitmaskElement; -#endif using namespace onnxruntime::test; enum TrainingMode { TrainingFalse, @@ -182,8 +173,6 @@ void RunBiasDropoutTest(const bool use_mask, const std::vector& input_s std::vector> t_eps; #ifdef USE_CUDA t_eps.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - t_eps.emplace_back(DefaultRocmExecutionProvider()); #endif t.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &t_eps); @@ -204,8 +193,6 @@ void RunBiasDropoutTest(const bool use_mask, const std::vector& input_s std::vector> t_bitmask_eps; #ifdef USE_CUDA t_bitmask_eps.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - t_bitmask_eps.emplace_back(DefaultRocmExecutionProvider()); #endif t_bitmask.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &t_bitmask_eps); } diff --git a/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc b/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc index bada23723e9f4..54c7cc16bc1a8 100644 --- a/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_softmax_op_test.cc @@ -13,12 +13,6 @@ namespace onnxruntime { namespace test { -#if USE_ROCM -constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; -#else -constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#endif - // followed example of fastgelu_op_test.cc // in retrospect would have been better to compare BiasSoftmax to Add + Softmax graph @@ -134,7 +128,7 @@ class BiasSoftmaxTester { void RunComparison() { // BiasSoftmax only implemented for cuda architecture int min_cuda_architecture = use_float16_ ? 
530 : 0; - if (HasCudaEnvironment(min_cuda_architecture) || kGpuExecutionProvider == kRocmExecutionProvider) { + if (HasCudaEnvironment(min_cuda_architecture)) { OpTester tester("BiasSoftmax", 1, onnxruntime::kMSDomain); tester.AddAttribute("axis", axis_); tester.AddAttribute("is_inner_broadcast", is_inner_broadcast_); @@ -152,8 +146,6 @@ class BiasSoftmaxTester { std::vector> ep; #ifdef USE_CUDA ep.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - ep.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep); diff --git a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc index a979717d23573..42db1c1201b63 100644 --- a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc @@ -74,7 +74,7 @@ std::vector GetExpectedResult(const std::vector& input_data, } } // namespace bias_split_gelu_test -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_DML) static void RunBiasSplitGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, @@ -85,10 +85,9 @@ static void RunBiasSplitGeluGpuTest(const std::vector& input_data, bool use_float16 = false) { int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); - if (!enable_cuda && !enable_rocm && !enable_dml) { + if (!enable_cuda && !enable_dml) { return; } @@ -108,9 +107,7 @@ static void RunBiasSplitGeluGpuTest(const std::vector& input_data, if (enable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } + if (enable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } diff --git a/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc index 7ca4e1004066c..926c45cadcc1b 100644 --- a/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc +++ b/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc @@ -1,29 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/util/include/default_providers.h" -#ifdef USE_ROCM -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#else #include "core/providers/cuda/shared_inc/cuda_utils.h" -#endif namespace onnxruntime { namespace contrib { namespace test { -#ifdef USE_ROCM -using onnxruntime::rocm::BitmaskElementType; -using onnxruntime::rocm::kNumBitsPerBitmaskElement; -#else using onnxruntime::cuda::BitmaskElementType; using onnxruntime::cuda::kNumBitsPerBitmaskElement; -#endif using namespace onnxruntime::test; namespace { @@ -62,8 +53,6 @@ void RunTestForInference(const std::vector& input_dims, bool has_ratio std::vector> test_eps; #ifdef USE_CUDA test_eps.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - test_eps.emplace_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_eps); } @@ -123,8 +112,6 @@ void RunTestForTraining(const std::vector& input_dims) { std::vector> dropout_eps; #ifdef USE_CUDA dropout_eps.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - dropout_eps.emplace_back(DefaultRocmExecutionProvider()); #endif dropout.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &dropout_eps); @@ -146,8 +133,6 @@ void RunTestForTraining(const std::vector& input_dims) { std::vector> bitmask_dropout_eps; #ifdef USE_CUDA bitmask_dropout_eps.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - bitmask_dropout_eps.emplace_back(DefaultRocmExecutionProvider()); #endif bitmask_dropout.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &bitmask_dropout_eps); } diff --git a/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc index 8a37ef921fd2b..3864baa7c16e2 100644 --- a/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc @@ -33,10 +33,9 @@ static void RunAttentionTest( const std::vector* value_cache = nullptr, const std::initializer_list* key_padding_mask_data = nullptr) { bool enable_cuda = HasCudaEnvironment(0); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_cpu = false; - if (enable_cpu || enable_cuda || enable_rocm) { + if (enable_cpu || enable_cuda) { OpTester tester("DecoderAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); @@ -103,9 +102,7 @@ static void RunAttentionTest( if (enable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index 38659fbd9f2b9..3fd27e2ead7c6 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -109,7 +109,7 @@ TEST(BiasGeluTest, Float) { RunBiasGeluTestFloat({2, 2333}, {2333}); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_DML) || defined(USE_WEBGPU) static void RunBiasGeluTestHalf(const 
std::vector& input_dims, const std::vector& bias_dims) { RandomValueGenerator random{2333}; std::vector input_data = random.Uniform(input_dims, -1.0f, 1.0f); @@ -147,7 +147,7 @@ TEST(BiasGeluTest, MLFloat16) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) +#if defined(USE_CUDA) || defined(USE_DNNL) static void RunBiasGeluTestBFloat16(const std::vector& input_dims, const std::vector& bias_dims) { RandomValueGenerator random{2333}; std::vector input_data = random.Uniform(input_dims, 0.5f, 1.5f); @@ -164,8 +164,6 @@ static void RunBiasGeluTestBFloat16(const std::vector& input_dims, cons std::vector> execution_providers; #if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif defined(USE_DNNL) execution_providers.push_back(DefaultDnnlExecutionProvider()); #elif defined(USE_DML) @@ -197,7 +195,7 @@ TEST(BiasGeluTest, BFloat16) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(MathOpTest, ComplexMul) { std::vector input_a_data = { -0.5f, 0.6f}; @@ -220,8 +218,6 @@ TEST(MathOpTest, ComplexMul) { std::vector> execution_providers; #if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -248,8 +244,6 @@ TEST(MathOpTest, ComplexMulConj) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -276,8 +270,6 @@ TEST(MathOpTest, ComplexMul_fp16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -304,8 +296,6 @@ TEST(MathOpTest, ComplexMulConj_fp16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc index 9ecbb04ebccca..6cd84fc55ea86 100644 --- a/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc @@ -17,11 +17,10 @@ static void RunTest(const embedlayernorm::OpData& data, int min_cuda_architecture = use_float16 ? 
530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = DefaultRocmExecutionProvider().get() != nullptr; bool enable_dml = DefaultDmlExecutionProvider().get() != nullptr; bool enable_cpu = !use_float16; - if (enable_cpu || enable_cuda || enable_dml || enable_rocm) { + if (enable_cpu || enable_cuda || enable_dml) { // Input and output shapes // Input 0 - input_ids : (batch_size, sequence_size) // Input 1 - segment_ids : (batch_size, sequence_size) @@ -149,10 +148,6 @@ static void RunTest(const embedlayernorm::OpData& data, std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } else if (enable_rocm) { - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (enable_dml) { std::vector> execution_providers; execution_providers.push_back(DefaultDmlExecutionProvider()); diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 497b8b5fd6cc7..3490516f32099 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -41,7 +41,7 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_WEBGPU) static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& bias_dims, const std::vector& output_dims, @@ -73,8 +73,6 @@ static void RunFastGeluGpuTest(const std::vector& input_data, const std:: std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_WEBGPU execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif @@ -144,7 +142,7 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_WEBGPU) RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); #endif RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); @@ -247,8 +245,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); } -// CUDA, ROCm and WebGPU only for Float16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +// CUDA and WebGPU only for Float16 type. +#if defined(USE_CUDA) || defined(USE_WEBGPU) TEST(FastGeluTest, FastGeluWithBiasFloat16_2) { int batch_size = 1; int sequence_length = 2; @@ -385,8 +383,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) { } #endif -// CUDA and ROCm only for BFloat16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) +// CUDA only for BFloat16 type. 
+#if defined(USE_CUDA) TEST(FastGeluTest, FastGeluWithBias_BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 800; @@ -433,15 +431,13 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } #endif -// CUDA and ROCm only for double type. -#if defined(USE_CUDA) || defined(USE_ROCM) +// CUDA only for double type. +#if defined(USE_CUDA) TEST(FastGeluTest, FastGeluWithBias_Double) { OpTester tester("FastGelu", 1, onnxruntime::kMSDomain); @@ -471,8 +467,6 @@ TEST(FastGeluTest, FastGeluWithBias_Double) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc index 7a6b6cca6425a..1a75f2b12d5eb 100644 --- a/onnxruntime/test/contrib_ops/fft_op_test.cc +++ b/onnxruntime/test/contrib_ops/fft_op_test.cc @@ -8,15 +8,12 @@ namespace onnxruntime { namespace test { TEST(ContribOpTest, Rfft) { - if (DefaultCudaExecutionProvider() == nullptr && DefaultRocmExecutionProvider() == nullptr) return; + if (DefaultCudaExecutionProvider() == nullptr) return; std::vector> execution_providers; if (DefaultCudaExecutionProvider() != nullptr) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (DefaultRocmExecutionProvider() != nullptr) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } OpTester test("Rfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast(1)); @@ -30,15 +27,12 @@ TEST(ContribOpTest, Rfft) { } TEST(ContribOpTest, Irfft) { - if (DefaultCudaExecutionProvider() == nullptr && DefaultRocmExecutionProvider() == nullptr) return; + if (DefaultCudaExecutionProvider() == nullptr) return; std::vector> execution_providers; if (DefaultCudaExecutionProvider() != nullptr) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (DefaultRocmExecutionProvider() != nullptr) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } OpTester test("Irfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast(1)); diff --git a/onnxruntime/test/contrib_ops/fused_conv_test.cc b/onnxruntime/test/contrib_ops/fused_conv_test.cc index 0dd69a49972e8..9df222db43501 100644 --- a/onnxruntime/test/contrib_ops/fused_conv_test.cc +++ b/onnxruntime/test/contrib_ops/fused_conv_test.cc @@ -32,17 +32,14 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, const vector& expected_output_shape, bool disable_cpu = false, bool disable_cuda = false, - bool disable_rocm = false, bool disable_webgpu = false, bool use_float16 = false, bool weight_is_initializer = false) { bool enable_cuda = HasCudaEnvironment(0) && !use_float16 && !disable_cuda; - // Only ROCm EP supports float16. 
- bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !disable_rocm; bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()) && !disable_webgpu; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; - if (enable_cuda || enable_rocm || enable_cpu || enable_webgpu) { + if (enable_cuda || enable_cpu || enable_webgpu) { OpTester test("FusedConv", 1, onnxruntime::kMSDomain); test.AddAttribute("group", attributes.group); test.AddAttribute("kernel_shape", attributes.kernel_shape); @@ -94,10 +91,6 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } - if (enable_webgpu) { execution_providers.push_back(DefaultWebGpuExecutionProvider()); } @@ -116,16 +109,15 @@ void RunConvOp(const ConvOpAndTestAttributes& attributes, const vector& expected_output_shape, bool disable_cpu = false, bool disable_cuda = false, - bool disable_rocm = false, bool disable_webgpu = false) { bool weight_is_initializer = false; bool use_float16 = false; TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, - disable_cpu, disable_cuda, disable_rocm, disable_webgpu, use_float16, weight_is_initializer); + disable_cpu, disable_cuda, disable_webgpu, use_float16, weight_is_initializer); use_float16 = true; TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, - disable_cpu, disable_cuda, disable_rocm, disable_webgpu, use_float16, weight_is_initializer); + disable_cpu, disable_cuda, disable_webgpu, use_float16, weight_is_initializer); } TEST(FusedConvTest, Conv2D_HardSigmoid) { @@ -146,7 +138,7 @@ TEST(FusedConvTest, Conv2D_HardSigmoid) { vector W_shape = {2, 1, 2, 2}; vector Y_shape = {1, 2, 2, 2}; auto expected_vals = {0.8f, 0.9f, 1.0f, 1.0f, 0.2f, 0.1f, 0.0f, 0.0f}; - RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, true, true, true); + RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, true, true); } TEST(FusedConvTest, Conv2D_Relu) { @@ -191,7 +183,7 @@ TEST(FusedConvTest, Conv2D_Bias_Relu) { RunConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(FusedConvTest, Conv2D_Bias_Z_Relu) { ConvOpAndTestAttributes attrs = { @@ -214,7 +206,7 @@ TEST(FusedConvTest, Conv2D_Bias_Z_Relu) { vector Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}; vector Z_shape = {1, 2, 2, 2}; auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f}; - RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, true, false, false); + RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, true, false); } #endif @@ -240,7 +232,7 @@ TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) { vector Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}; vector Z_shape = {1, 2, 2, 2}; auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f}; - RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, false, true, true, true); + RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, false, true, true); } #endif diff --git a/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc b/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc index 
b1762d16795d1..8b15ac5300a82 100644 --- a/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc +++ b/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc @@ -221,7 +221,7 @@ TEST(FusedMatMulOpTest, FloatTypeNoTranspose) { RunFusedMatMulTest("FusedMatMul", 1); } -#if defined(USE_CUDA) || defined(USE_ROCM) // double support only implemented in CUDA/ROCM kernel +#if defined(USE_CUDA) // double support only implemented in CUDA kernel TEST(FusedMatMulOpTest, DoubleTypeNoTranspose) { RunFusedMatMulTest("FusedMatMul", 1); } @@ -270,7 +270,7 @@ TEST(FusedMatMulOpTest, FloatTypeTransposeBatch) { RunFusedMatMulTest("FusedMatMul", 1, true, true, true, true); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_DML) TEST(FusedMatMulOpTest, Float16_NoTranspose) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -315,7 +315,7 @@ TEST(FusedMatMulOpTest, Float16_NoTranspose) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) +#if defined(USE_CUDA) || defined(USE_DNNL) TEST(FusedMatMulOpTest, BFloat16_NoTranspose) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -371,8 +371,6 @@ TEST(FusedMatMulOpTest, BFloat16_NoTranspose) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_DNNL execution_providers.push_back(DefaultDnnlExecutionProvider()); #endif diff --git a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc deleted file mode 100644 index 6b67b648fd9b2..0000000000000 --- a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include "core/platform/threadpool.h" -#include "core/util/math.h" -#include "core/util/thread_utils.h" -#include "test/common/cuda_op_test_utils.h" -#include "test/common/tensor_op_test_utils.h" -#include "test/providers/provider_test_utils.h" - -namespace onnxruntime { -namespace test { -namespace gemmfastgelu { - -#if defined(USE_ROCM) -namespace { - -const onnxruntime::RunOptions run_options = []() { - onnxruntime::RunOptions options{}; - ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); - return options; -}(); - -const constexpr auto run_with_tunable_op = &run_options; - -} // namespace - -static void RunGemmFastGeluGpuTest(const std::vector& input_data, const std::vector& weight_data, - const std::vector& bias_data, const std::vector& output_data, - const std::vector& input_dims, const std::vector& weight_dims, - const std::vector& bias_dims, const std::vector& output_dims, - bool has_bias, bool use_float16 = false) { - OpTester tester("GemmFastGelu", 1, onnxruntime::kMSDomain); - - if (use_float16) { - tester.AddInput("X", input_dims, ToFloat16(input_data)); - tester.AddInput("W", weight_dims, ToFloat16(weight_data)); - if (has_bias) { - tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); - } - tester.AddOutput("Y", output_dims, ToFloat16(output_data)); - } else { - tester.AddInput("X", input_dims, input_data); - tester.AddInput("W", weight_dims, weight_data); - if (has_bias) { - tester.AddInput("bias", bias_dims, bias_data); - } - tester.AddOutput("Y", output_dims, output_data); - } - - tester.SetOutputTolerance(use_float16 ? 
0.005f : 0.0025f); - - tester.Config(run_with_tunable_op) - .RunWithConfig(); -} - -TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat32) { - int batch_size = 1; - int sequence_length = 2; - int hidden_size = 4; - int dense_size = 6; - - std::vector input_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f}; - - std::vector weight_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f, - 0.7f, -0.5f, 0.7f, 1.2f, - 0.3f, 0.1f, 0.8f, -1.6f, - 0.9f, -0.1f, 3.0f, 2.f, - 0.4f, -0.7f, -0.3f, 0.6f}; - - std::vector bias_data = {}; - - std::vector output_data = { - 3.4894f, 1.8455f, 0.0260f, 0.2229f, -0.1003f, 0.0902f, - -0.1323f, -0.0953f, 0.0778f, 0.2152f, 0.6715f, -0.0240f}; - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weight_dims = {hidden_size, dense_size}; - std::vector bias_dims = {dense_size}; - std::vector output_dims = {batch_size, sequence_length, dense_size}; - - RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, - input_dims, weight_dims, bias_dims, output_dims, - false); -} - -TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat32) { - int batch_size = 1; - int sequence_length = 2; - int hidden_size = 4; - int dense_size = 6; - - std::vector input_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f}; - - std::vector weight_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f, - 0.7f, -0.5f, 0.7f, 1.2f, - 0.3f, 0.1f, 0.8f, -1.6f, - 0.9f, -0.1f, 3.0f, 2.f, - 0.4f, -0.7f, -0.3f, 0.6f}; - - std::vector bias_data = { - -0.5f, 0.6f, 1.2f, 2.1f, -0.6f, 0.4f}; - - std::vector output_data = { - 2.9862f, 2.4849f, 1.1177f, 2.4329f, -0.1681f, 0.3988f, - -0.0702f, -0.1633f, 1.2190f, 2.4225f, 0.1428f, 0.2229f}; - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weight_dims = {hidden_size, dense_size}; - std::vector bias_dims = {dense_size}; - std::vector output_dims = {batch_size, sequence_length, dense_size}; - - RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, - input_dims, weight_dims, bias_dims, output_dims, - true); -} - -TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) { - int batch_size = 1; - int sequence_length = 2; - int hidden_size = 4; - int dense_size = 6; - - std::vector input_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f}; - - std::vector weight_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f, - 0.7f, -0.5f, 0.7f, 1.2f, - 0.3f, 0.1f, 0.8f, -1.6f, - 0.9f, -0.1f, 3.0f, 2.f, - 0.4f, -0.7f, -0.3f, 0.6f}; - - std::vector bias_data = {}; - - std::vector output_data = { - 3.4902f, 1.8467f, 0.0259f, 0.2227f, -0.1005f, 0.0901f, - -0.1324f, -0.0955f, 0.0778f, 0.2156f, 0.6714f, -0.0241f}; - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weight_dims = {hidden_size, dense_size}; - std::vector bias_dims = {dense_size}; - std::vector output_dims = {batch_size, sequence_length, dense_size}; - - RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, - input_dims, weight_dims, bias_dims, output_dims, - false, true); -} - -TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat16) { - int batch_size = 1; - int sequence_length = 2; - int hidden_size = 4; - int dense_size = 6; - - std::vector input_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f}; - - std::vector weight_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f, - 0.7f, -0.5f, 0.7f, 1.2f, - 0.3f, 0.1f, 0.8f, -1.6f, - 0.9f, -0.1f, 3.0f, 2.f, - 0.4f, -0.7f, -0.3f, 0.6f}; - - std::vector 
bias_data = { - -0.5f, 0.6f, 1.2f, 2.1f, -0.6f, 0.4f}; - - std::vector output_data = { - 2.9883f, 2.4844f, 1.1182f, 2.4316f, -0.1680f, 0.3984f, - -0.0701f, -0.1633f, 1.2178f, 2.4219f, 0.1426f, 0.2227f}; - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weight_dims = {hidden_size, dense_size}; - std::vector bias_dims = {dense_size}; - std::vector output_dims = {batch_size, sequence_length, dense_size}; - - RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, - input_dims, weight_dims, bias_dims, output_dims, - true, true); -} - -TEST(GemmFastGeluTest, GemmFastGeluWithBias_bfloat16) { - OpTester tester("GemmFastGelu", 1, onnxruntime::kMSDomain); - - int batch_size = 1; - int sequence_length = 2; - int hidden_size = 4; - int dense_size = 6; - - std::vector input_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f}; - - std::vector weight_data = { - 0.8f, -0.5f, 0.0f, 1.f, - 0.5f, 0.2f, 0.3f, -0.6f, - 0.7f, -0.5f, 0.7f, 1.2f, - 0.3f, 0.1f, 0.8f, -1.6f, - 0.9f, -0.1f, 3.0f, 2.f, - 0.4f, -0.7f, -0.3f, 0.6f}; - - std::vector bias_data = { - -0.5f, 0.6f, 1.2f, 2.1f, -0.6f, 0.4f}; - - std::vector output_data = { - 2.9883f, 2.4844f, 1.1182f, 2.4316f, -0.1680f, 0.3984f, - -0.0701f, -0.1633f, 1.2178f, 2.4219f, 0.1426f, 0.2227f}; - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weight_dims = {hidden_size, dense_size}; - std::vector bias_dims = {dense_size}; - std::vector output_dims = {batch_size, sequence_length, dense_size}; - - std::vector f_X = FloatsToBFloat16s(input_data); - std::vector f_W = FloatsToBFloat16s(weight_data); - std::vector f_B = FloatsToBFloat16s(bias_data); - std::vector f_Y = FloatsToBFloat16s(output_data); - - tester.AddInput("X", input_dims, f_X); - tester.AddInput("W", weight_dims, f_W); - tester.AddInput("bias", bias_dims, f_B); - tester.AddOutput("Y", output_dims, f_Y); - - tester.Config(run_with_tunable_op) - .RunWithConfig(); -} -#endif - -} // namespace gemmfastgelu -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index be72fbd460c9b..04a57dd9c9b2c 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -60,13 +60,8 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) { #else bool is_cuda = false; #endif -#ifdef USE_ROCM - bool is_rocm = true; -#else - bool is_rocm = false; -#endif - if (is_cuda || is_rocm) { + if (is_cuda) { Ort::SessionOptions session_options; #ifdef USE_CUDA if (is_cuda) { @@ -142,13 +137,8 @@ TEST(GreedySearchTest, GptGreedySearchFp32) { #else bool is_cuda = false; #endif -#ifdef USE_ROCM - bool is_rocm = true; -#else - bool is_rocm = false; -#endif - if (is_cuda || is_rocm) { + if (is_cuda) { Ort::SessionOptions session_options; #ifdef USE_CUDA if (is_cuda) { diff --git a/onnxruntime/test/contrib_ops/group_norm_op_test.cc b/onnxruntime/test/contrib_ops/group_norm_op_test.cc index fdc546441676b..5227509368f45 100644 --- a/onnxruntime/test/contrib_ops/group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_norm_op_test.cc @@ -730,20 +730,17 @@ TEST(GroupNormTest, GroupNorm_128) { // Test float16, without activation int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); std::array 
channels_last_values = {-1, 0, 1}; for (const int channels_last : channels_last_values) { - if (enable_cuda || enable_rocm || enable_dml) { + if (enable_cuda || enable_dml) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm && channels_last != 0) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } + if (enable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } @@ -784,14 +781,12 @@ TEST(GroupNormTest, GroupNorm_128) { // Test float32, with activation enable_cuda = HasCudaEnvironment(0); - if (enable_cuda || enable_rocm || enable_dml) { + if (enable_cuda || enable_dml) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm && channels_last != 0) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } + if (enable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 0d4fc5af68b4f..d08df321a963b 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -21,7 +21,7 @@ using namespace std; namespace onnxruntime { namespace test { -// Some feature (like broadcast support) are implemented in CPU and CUDA/ROCM provider only. A helper to run tests. +// Some feature (like broadcast support) are implemented in CPU and CUDA provider only. A helper to run tests. void RunTestOnCpuAndCuda(OpTester& test, const std::string& expected_failure_msg = "") { auto expected_result = expected_failure_msg.empty() ? 
OpTester::ExpectResult::kExpectSuccess @@ -33,13 +33,11 @@ void RunTestOnCpuAndCuda(OpTester& test, const std::string& expected_failure_msg constexpr int min_cuda_architecture = 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); - if (enable_cuda || enable_rocm) { + + if (enable_cuda) { std::vector> gpu_execution_provider; if (enable_cuda) { gpu_execution_provider.push_back(DefaultCudaExecutionProvider()); - } else if (enable_rocm) { - gpu_execution_provider.push_back(DefaultRocmExecutionProvider()); } if (gpu_execution_provider.size() > 0) { diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 46082e1b0cd31..75e1e0856bc7e 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -6,7 +6,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_DML) || defined(USE_WEBGPU) constexpr auto k_epsilon_default = 1e-5f; constexpr auto k_random_data_min = -10.0f; constexpr auto k_random_data_max = 10.0f; @@ -80,8 +80,6 @@ static void TestLayerNorm(const std::vector& x_dims, #ifdef USE_CUDA test.CompareWithCPU(kCudaExecutionProvider); -#elif USE_ROCM - test.CompareWithCPU(kRocmExecutionProvider); #elif USE_DML test.CompareWithCPU(kDmlExecutionProvider); #elif USE_WEBGPU diff --git a/onnxruntime/test/contrib_ops/longformer_attention_op_test.cc b/onnxruntime/test/contrib_ops/longformer_attention_op_test.cc index 3e5c9a10f32b5..c7c03ed7080ae 100644 --- a/onnxruntime/test/contrib_ops/longformer_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/longformer_attention_op_test.cc @@ -29,7 +29,6 @@ static void RunAttentionTest( int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_cpu = false; if (enable_cpu || enable_cuda) { OpTester tester("LongformerAttention", 1, onnxruntime::kMSDomain); @@ -69,9 +68,7 @@ static void RunAttentionTest( if (enable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } + if (enable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index e5cfb946999a6..21e2003bf9acf 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -502,7 +502,7 @@ TEST(MatMulNBits, LegacyShape_4b) { #endif #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_DML) || defined(USE_WEBGPU) namespace { // Legacy test function. 
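The pattern being trimmed in these hunks is the same one used across the contrib-op tests: compile the GPU test body only when the corresponding EP macro is defined, then register that EP's default provider with the OpTester at run time. A minimal sketch of that pattern, with a hypothetical test name and an illustrative op, assuming the usual ORT test harness headers:

#include <memory>
#include <vector>
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"   // OpTester
#include "test/util/include/default_providers.h"  // Default*ExecutionProvider()

namespace onnxruntime {
namespace test {

TEST(ProviderGatingSketch, RunOnCompiledProviders) {
  OpTester tester("Relu", 14);  // illustrative op/opset only
  tester.AddInput<float>("X", {2}, {-1.0f, 1.0f});
  tester.AddOutput<float>("Y", {2}, {0.0f, 1.0f});

  // Register whichever provider this build was compiled with.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#if defined(USE_CUDA)
  execution_providers.push_back(DefaultCudaExecutionProvider());
#elif defined(USE_DML)
  execution_providers.push_back(DefaultDmlExecutionProvider());
#else
  execution_providers.push_back(DefaultCpuExecutionProvider());
#endif
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

}  // namespace test
}  // namespace onnxruntime

Dropping ROCm therefore removes both halves of the pattern in each test: the USE_ROCM branch of the preprocessor guard and the DefaultRocmExecutionProvider() registration.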
@@ -538,10 +538,6 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop execution_providers.push_back(DefaultCudaExecutionProvider()); RunTest(opts, std::move(execution_providers)); #endif -#ifdef USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); - RunTest(opts, std::move(execution_providers)); -#endif #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); RunTest(opts, std::move(execution_providers)); @@ -551,9 +547,6 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop RunTest(opts, std::move(execution_providers)); #endif } else { -#ifdef USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); -#endif #ifdef USE_WEBGPU ConfigOptions config_options{}; ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kMaxStorageBufferBindingSize, "134217728").IsOK()); @@ -737,7 +730,7 @@ TEST(MatMulNBits, BFloat16_Int4_NoZeroPoint) { } #endif -#endif // defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#endif // defined(USE_CUDA) || defined(USE_DML) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index a7d5f15698f0c..c740959105977 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -9,18 +9,6 @@ #include "test/util/include/scoped_env_vars.h" #include "test/contrib_ops/attention_op_test_helper.h" -#if defined(USE_ROCM) && defined(USE_COMPOSABLE_KERNEL) && !defined(USE_MIGRAPHX) -#define DISABLE_ROCM false -#else -#define DISABLE_ROCM true -#endif - -#if defined(USE_ROCM) -#define ROCM_GTEST_SKIP(message) GTEST_SKIP_(message) -#else -#define ROCM_GTEST_SKIP(message) -#endif - namespace onnxruntime { namespace test { @@ -57,30 +45,17 @@ static void RunMultiHeadAttentionTest( bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, bool disable_webgpu = false, - bool disable_rocm = DISABLE_ROCM, // not supported in rocm right now. bool disable_dml = false) { kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length); int past_sequence_length = (past_seq_len_data.size() == 0) ? 0 : past_seq_len_data[0]; int min_cuda_architecture = use_float16 ? 
750 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !disable_cuda; - // rocm mha is required to work with TunableOp Enabled - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider(/*test_tunable_op=*/true).get()) && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()) && !disable_webgpu; - if (enable_rocm && !use_float16) { - LOGS_DEFAULT(WARNING) << "ROCm MHA only have kernel for half datatype implemented, skip float datatype tests"; - enable_rocm = false; - } - - if (enable_rocm && !bias_data.empty()) { - LOGS_DEFAULT(WARNING) << "ROCm MHA does not support qkv_bias, skip qkv_bias tests"; - enable_rocm = false; - } - - if (enable_cpu || enable_cuda || enable_rocm || enable_dml || enable_webgpu) { + if (enable_cpu || enable_cuda || enable_dml || enable_webgpu) { OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("unidirectional", static_cast(is_unidirectional ? 1 : 0)); @@ -301,12 +276,6 @@ static void RunMultiHeadAttentionTest( tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (enable_rocm) { - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true)); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (enable_cpu) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -361,7 +330,6 @@ static void RunMultiHeadAttentionKernel( bool disable_cpu = false, // some cases not supported in cpu right now. 
bool disable_cuda = false, bool disable_webgpu = false, - bool disable_rocm = DISABLE_ROCM, bool disable_dml = false) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { ScopedEnvironmentVariables scoped_env_vars{ @@ -377,7 +345,7 @@ static void RunMultiHeadAttentionKernel( present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, output_qk_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, num_beams, max_sequence_length, is_static_kv, buffer_share, use_float16, - is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_dml); return; } @@ -395,7 +363,7 @@ static void RunMultiHeadAttentionKernel( present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, output_qk_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, num_beams, max_sequence_length, is_static_kv, buffer_share, use_float16, - is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_dml); return; } @@ -413,7 +381,7 @@ static void RunMultiHeadAttentionKernel( present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, output_qk_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, num_beams, max_sequence_length, is_static_kv, buffer_share, use_float16, - is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_dml); return; } @@ -432,7 +400,7 @@ static void RunMultiHeadAttentionKernel( present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, output_qk_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, num_beams, max_sequence_length, is_static_kv, buffer_share, use_float16, - is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_dml); return; } #endif @@ -452,7 +420,7 @@ static void RunMultiHeadAttentionKernel( present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, output_qk_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, num_beams, max_sequence_length, is_static_kv, buffer_share, use_float16, - is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { @@ -470,7 +438,7 @@ static void RunMultiHeadAttentionKernel( present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, output_qk_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, num_beams, max_sequence_length, is_static_kv, buffer_share, use_float16, - is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + is_unidirectional, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } } @@ -479,7 +447,6 @@ enum RunMultiHeadAttentionTestToggles : uint32_t { DISABLE_CPU = 1 << 0, DISABLE_CUDA = 1 << 1, DISABLE_WEBGPU = 1 << 2, - DISABLE_ROCM_MHA = 1 << 3, DISABLE_DML = 1 << 4, }; inline 
RunMultiHeadAttentionTestToggles operator|(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { @@ -494,7 +461,6 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = toggles & DISABLE_CPU; bool disable_cuda = toggles & DISABLE_CUDA; bool disable_webgpu = toggles & DISABLE_WEBGPU; - bool disable_rocm = toggles & DISABLE_ROCM_MHA; bool disable_dml = toggles & DISABLE_DML; if (data.fp32_output_data.size() > 0) { @@ -508,7 +474,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.fp32_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -522,7 +488,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.fp32_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } } #endif @@ -534,7 +500,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.fp32_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } if (data.fp16_output_data.size() > 0) { @@ -547,7 +513,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.fp16_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; @@ -558,7 +524,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.fp16_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, 
disable_cpu, disable_cuda, disable_webgpu, disable_dml); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -570,7 +536,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.fp16_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } #endif @@ -582,7 +548,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.fp16_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } kernel_type = AttentionKernelType::AttentionKernel_Default; @@ -592,14 +558,13 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.fp16_output_qk_data, kernel_type, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, data.num_beams, data.max_sequence_length, data.is_static_kv, data.buffer_share, use_float16, - false, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); + false, disable_cpu, disable_cuda, disable_webgpu, disable_dml); } } // Test fused cross attention kernel // It requires head_size > 32 and head_size <= 64 for T4 GPU; hidden_size == v_hidden_size. 
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_HeadSize40(data); RunMultiHeadAttentionTests(data); @@ -609,7 +574,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) { - ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); @@ -619,7 +583,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); @@ -629,7 +592,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); @@ -639,14 +601,12 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); @@ -654,7 +614,6 @@ TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_Packe // This tests qk_head_size != v_head_size TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_HeadSize16_8(data); RunMultiHeadAttentionTests(data); @@ -664,7 +623,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) { } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_HeadSize16(data); RunMultiHeadAttentionTests(data); @@ -674,7 +632,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); RunMultiHeadAttentionTests(data, DISABLE_CUDA); @@ -684,7 +641,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { // Bug #50220930 #ifndef USE_DML TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { - ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetCrossAttentionDataWithPast(data); 
RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); @@ -692,22 +648,18 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { #endif TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { - ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); RunMultiHeadAttentionTests(data, DISABLE_CPU); } TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); - // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; GetAttentionDataCutlassAttnBias(data); RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); // Whisper decoder cross attention without mask and different sequence lengths for Q and K/V AttentionTestData data; GetCrossAttentionData_DiffSequenceLengths(data); @@ -721,7 +673,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { } TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBias) { - ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); // Whisper decoder self attention with past_kv and present_kv AttentionTestData data; GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(data); @@ -734,7 +685,7 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBia RunMultiHeadAttentionTests(data, DISABLE_CUDA); } -// This test is disabled since it is not used in Whisper anymore, and it fails in ROCm. +// This test is disabled since it is not used in Whisper anymore. TEST(MultiHeadAttentionTest, DISABLED_CrossAttention_WithPastPassedInDirectly_NoMask) { // Whisper decoder cross attention with past_kv in place of current KV and no present_kv AttentionTestData data; @@ -749,7 +700,7 @@ TEST(MultiHeadAttentionTest, SelfAttention_PastPresentBufferShare_UsingDMMHAInsi // See onnxruntime/core/graph/contrib_ops/bert_defs.cc for more details AttentionTestData data; GetSelfAttention_PastPresentBufferShare_UsingDMMHAInsideMHA(data); - RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_ROCM_MHA | DISABLE_WEBGPU | DISABLE_DML); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU | DISABLE_DML); } TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths_UsingDMMHAInsideMHA) { @@ -757,7 +708,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths_UsingDMMHAInside // Used in decoder-with-past's cross-attention layers AttentionTestData data; GetCrossAttention_DiffSequenceLengths_UsingDMMHAInsideMHA(data); - RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_ROCM_MHA | DISABLE_WEBGPU | DISABLE_DML); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU | DISABLE_DML); } } // namespace test diff --git a/onnxruntime/test/contrib_ops/ngram_repeat_block_op_test.cc b/onnxruntime/test/contrib_ops/ngram_repeat_block_op_test.cc index 09b98aa50bd7a..f57882473e30b 100644 --- a/onnxruntime/test/contrib_ops/ngram_repeat_block_op_test.cc +++ b/onnxruntime/test/contrib_ops/ngram_repeat_block_op_test.cc @@ -31,13 +31,6 @@ TEST(NGramRepeatBlockTest, NGramSize_3) { execution_providers.push_back(DefaultCudaExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -#ifdef USE_ROCM - if (nullptr != DefaultRocmExecutionProvider().get()) { - std::vector> execution_providers; - 
execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -#endif std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); diff --git a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc index e780d35df08bd..850bea4351914 100644 --- a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc @@ -32,10 +32,9 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes, int min_cuda_architecture = use_float16 ? 530 : 0; // NHWC implementation doesn't handle W in NHWC layout if it's not an initializer bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && weight_is_initializer; - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); - if (enable_cuda || enable_rocm || enable_dml) { + if (enable_cuda || enable_dml) { OpTester test("NhwcConv", 1, onnxruntime::kMSDomain); test.AddAttribute("group", attributes.group); test.AddAttribute("kernel_shape", attributes.kernel_shape); @@ -80,10 +79,6 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes, execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } - if (enable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } diff --git a/onnxruntime/test/contrib_ops/remove_padding_op_test.cc b/onnxruntime/test/contrib_ops/remove_padding_op_test.cc index d1a189de9ad4a..fe415e09fde62 100644 --- a/onnxruntime/test/contrib_ops/remove_padding_op_test.cc +++ b/onnxruntime/test/contrib_ops/remove_padding_op_test.cc @@ -22,14 +22,12 @@ static void RunRemovePadding( int total_tokens, bool use_float16 = false, const bool disable_cpu = true, - const bool disable_cuda = false, - const bool disable_rocm = true) { + const bool disable_cuda = false) { int min_cuda_architecture = use_float16 ? 
530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !disable_cuda; - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; - if (enable_cpu || enable_cuda || enable_rocm) { + if (enable_cpu || enable_cuda) { OpTester tester("RemovePadding", 1, onnxruntime::kMSDomain); // shape of inputs: @@ -68,12 +66,6 @@ static void RunRemovePadding( tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (enable_rocm) { - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (enable_cpu) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -96,15 +88,14 @@ static void RunRemovePaddingTests( bool use_float16 = false; constexpr bool disable_cpu = true; constexpr bool disable_cuda = false; - constexpr bool disable_rocm = true; RunRemovePadding(input_data, sequence_token_count_data, output_data, token_offset_data, cumulated_seq_len_data, max_token_count, batch_size, sequence_length, hidden_size, total_tokens, - use_float16, disable_cpu, disable_cuda, disable_rocm); + use_float16, disable_cpu, disable_cuda); use_float16 = true; RunRemovePadding(input_data, sequence_token_count_data, output_data, token_offset_data, cumulated_seq_len_data, max_token_count, batch_size, sequence_length, hidden_size, total_tokens, - use_float16, disable_cpu, disable_cuda, disable_rocm); + use_float16, disable_cpu, disable_cuda); } TEST(RemovePaddingTest, RemovePaddingBatch1_NoPadding) { diff --git a/onnxruntime/test/contrib_ops/restore_padding_op_test.cc b/onnxruntime/test/contrib_ops/restore_padding_op_test.cc index c8d49ce465bd6..3fef9857e4032 100644 --- a/onnxruntime/test/contrib_ops/restore_padding_op_test.cc +++ b/onnxruntime/test/contrib_ops/restore_padding_op_test.cc @@ -19,14 +19,12 @@ static void RunRestorePadding( int total_tokens, bool use_float16 = false, const bool disable_cpu = true, - const bool disable_cuda = false, - const bool disable_rocm = true) { + const bool disable_cuda = false) { int min_cuda_architecture = use_float16 ? 
530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !disable_cuda; - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; - if (enable_cpu || enable_cuda || enable_rocm) { + if (enable_cpu || enable_cuda) { OpTester tester("RestorePadding", 1, onnxruntime::kMSDomain); // shape of inputs: @@ -54,12 +52,6 @@ static void RunRestorePadding( tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (enable_rocm) { - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (enable_cpu) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -79,13 +71,12 @@ static void RunRestorePaddingTests( bool use_float16 = false; constexpr bool disable_cpu = true; constexpr bool disable_cuda = false; - constexpr bool disable_rocm = true; RunRestorePadding(input_data, output_data, token_offset_data, batch_size, sequence_length, hidden_size, total_tokens, - use_float16, disable_cpu, disable_cuda, disable_rocm); + use_float16, disable_cpu, disable_cuda); use_float16 = true; RunRestorePadding(input_data, output_data, token_offset_data, batch_size, sequence_length, hidden_size, total_tokens, - use_float16, disable_cpu, disable_cuda, disable_rocm); + use_float16, disable_cpu, disable_cuda); } TEST(RestorePaddingTest, RestorePaddingBatch1_NoPadding) { diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc index 69789b84832e0..b9bb1004332db 100644 --- a/onnxruntime/test/contrib_ops/sampling_test.cc +++ b/onnxruntime/test/contrib_ops/sampling_test.cc @@ -18,7 +18,7 @@ namespace onnxruntime { namespace test { #if defined(__linux__) && !defined(__ANDROID__) -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(SamplingTest, Gpt2Sampling_GPU) { std::vector input_ids{ 0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, @@ -73,10 +73,6 @@ TEST(SamplingTest, Gpt2Sampling_GPU) { OrtCUDAProviderOptionsV2 cuda_options; cuda_options.use_tf32 = false; session_options.AppendExecutionProvider_CUDA_V2(cuda_options); -#else // USE_ROCM - OrtROCMProviderOptions rocm_options; - // TODO - verify the default settings - session_options.AppendExecutionProvider_ROCM(rocm_options); #endif Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options); diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc index 3e8870892b7c9..c140f18cb9fe3 100644 --- a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -114,21 +114,16 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array channels_last_values = {-1, 1}; for (const int channels_last : channels_last_values) { - if (enable_cuda || enable_rocm) { + if (enable_cuda) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm && channels_last != 0) { - 
execution_providers.push_back(DefaultRocmExecutionProvider()); - } - // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; @@ -235,7 +230,6 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array has_add_out_values = {true, false}; std::array skip_dims = {2, 4}; @@ -243,16 +237,12 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { constexpr int channels_last = 1; for (const int skip_dim : skip_dims) { for (const bool has_add_out : has_add_out_values) { - if (enable_cuda || enable_rocm) { + if (enable_cuda) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } - if (enable_rocm && channels_last != 0) { - execution_providers.push_back(DefaultRocmExecutionProvider()); - } - // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index a1856b70f711f..85538efbefd28 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -60,7 +60,6 @@ static void RunOneTest( std::string op_type = simplified ? "SkipSimplifiedLayerNormalization" : "SkipLayerNormalization"; - auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); auto cpu_ep = DefaultCpuExecutionProvider(); auto webgpu_ep = DefaultWebGpuExecutionProvider(); @@ -147,7 +146,6 @@ static void RunOneTest( test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || - rocm_ep != nullptr || webgpu_ep != nullptr) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, ToFloat16(input_data)); @@ -186,8 +184,6 @@ static void RunOneTest( execution_providers.push_back(DefaultWebGpuExecutionProvider()); } else if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); - } else if (rocm_ep != nullptr) { - execution_providers.push_back(DefaultRocmExecutionProvider()); } else { if (strict) { Ort::CUDAProviderOptions cuda_options; @@ -877,7 +873,6 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) { simplified); } -#if !defined(USE_ROCM) TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) { int batch_size = 2; int sequence_length = 2; @@ -987,7 +982,6 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) { broadcast_skip, no_batch_size); } -#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 30595d5ce97b2..b91a054ab691b 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -4,10 +4,13 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include +#include #include "gsl/gsl" #include "gtest/gtest.h" #include "core/common/logging/sinks/file_sink.h" +#include "core/framework/kernel_def_builder.h" +#include "core/framework/op_kernel.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include 
"core/optimizer/graph_optimizer_registry.h" @@ -30,12 +33,22 @@ struct ApiPtrs { static void CheckStringInFile(const PathString& filename, const std::string& look_for) { std::ifstream ifs{filename}; + ASSERT_TRUE(ifs); std::string content(std::istreambuf_iterator{ifs}, std::istreambuf_iterator{}); EXPECT_NE(content.find(look_for), std::string::npos); } +static void CheckFileIsEmpty(const PathString& filename) { + std::ifstream ifs{filename}; + ASSERT_TRUE(ifs); + std::string content(std::istreambuf_iterator{ifs}, + std::istreambuf_iterator{}); + + EXPECT_TRUE(content.empty()); +} + // Normally, a plugin EP would be implemented in a separate library. // The `test_plugin_ep` namespace contains a local implementation intended for unit testing. namespace test_plugin_ep { @@ -121,14 +134,25 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = { *static_cast(ort_session_options), g_test_ort_ep_factory, ep_devices, + /*kernel_registry*/ nullptr, logging_manager.DefaultLogger()); auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; return result; } +using LookUpKernelFunc = std::function; + class MockKernelLookup : public IExecutionProvider::IKernelLookup { - const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override { return nullptr; } + public: + explicit MockKernelLookup(LookUpKernelFunc lookup = nullptr) : lookup_{lookup} {} + + const KernelCreateInfo* LookUpKernel(const Node& node) const override { + return lookup_ != nullptr ? lookup_(node) : nullptr; + } + + private: + LookUpKernelFunc lookup_ = nullptr; }; } // namespace test_plugin_ep @@ -435,10 +459,23 @@ static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode(OrtEp* this_ptr, cons return st; } - // Take only the first node using EpGraphSupportInfo_AddSingleNode(). - if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, nodes[0]); - st != nullptr) { - return st; + // Take only the first node that has a registered kernel for this EP. 
+ for (const OrtNode* node : nodes) { + const OrtKernelDef* kernel_def = nullptr; + OrtStatus* status = this_ep->ep_api->EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def); + + if (status != nullptr) { + return status; + } + + if (kernel_def != nullptr) { + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, node); + st != nullptr) { + return st; + } + + break; + } } return nullptr; @@ -454,7 +491,8 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { auto run_test = [&log_file](IExecutionProvider& ep, const std::unordered_set& nodes_for_other_ep, const std::unordered_set& nodes_for_this_ep, - const char* expected_log_string) { + const char* expected_log_string, + test_plugin_ep::LookUpKernelFunc lookup_kernel_func = nullptr) { std::shared_ptr model; ASSERT_NO_FATAL_FAILURE(LoadModelAndAssignNodesToEp(ORT_TSTR("testdata/add_mul_add.onnx"), "OtherEp", nodes_for_other_ep, model)); @@ -471,7 +509,7 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { GraphViewer graph_viewer(model->MainGraph()); auto compute_capabilities = ep.GetCapability(graph_viewer, - test_plugin_ep::MockKernelLookup{}, + test_plugin_ep::MockKernelLookup(lookup_kernel_func), GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()), nullptr); @@ -489,7 +527,12 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { } ASSERT_TRUE(std::filesystem::exists(log_file)); - EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string)); + + if (expected_log_string != nullptr) { + EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string)); + } else { + EXPECT_NO_FATAL_FAILURE(CheckFileIsEmpty(log_file)); + } }; constexpr std::array node_names = {"add_0", "mul_0", "add_1"}; @@ -536,6 +579,19 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + // Build dummy kernel definition for an Add node. Retrieved by OrtEp using EpGraphSupportInfo_LookUpKernel(). + KernelDefBuilder builder; + builder.SetName("Add").SinceVersion(1).Provider("TestOrtEp"); + auto add_kernel_create_info = std::make_unique(builder.Build(), nullptr); + + auto mock_kernel_lookup_fn = [&add_kernel_create_info](const Node& node) -> const KernelCreateInfo* { + // Only return a result for an Add node. + if (add_kernel_create_info->kernel_def->OpName() == node.OpType()) { + return add_kernel_create_info.get(); + } + return nullptr; + }; + // Load a model and assign the first Add node to another EP named 'OtherEp'. // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode. // IExecutionProvider::GetCapability() will return an empty result and log a warning. @@ -543,9 +599,153 @@ TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { nodes_for_other_ep = std::unordered_set{"add_0"}; nodes_for_this_ep = std::unordered_set{}; run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, - "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'", + mock_kernel_lookup_fn); + + // Load a model and assign the last Add node to another EP named 'OtherEp'. 
+ // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode. + // IExecutionProvider::GetCapability() will return a single capability and will not log warnings. + ort_ep->GetCapability = GetCapabilityTakeSingleNode; + nodes_for_other_ep = std::unordered_set{"add_1"}; + nodes_for_this_ep = std::unordered_set{"add_0"}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + /*expected_log_string*/ nullptr, mock_kernel_lookup_fn); std::filesystem::remove(log_file); } +// Test plugin EP's use of the EpGraphSupportInfo_LookUpKernel API. +TEST(PluginExecutionProviderTest, GetCapability_LookUpKernel) { + // Helper that calls IExecutionProvider::GetCapability and checks expected results. + auto run_test = [](IExecutionProvider& ep, const std::unordered_set& expected_claimed_nodes, + test_plugin_ep::LookUpKernelFunc lookup_kernel_func) { + const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ORT_TSTR("testdata/add_mul_add.onnx"), model, nullptr, + DefaultLoggingManager().DefaultLogger())); + + { + ep.SetLogger(&logger); + + GraphViewer graph_viewer(model->MainGraph()); + auto compute_capabilities = ep.GetCapability(graph_viewer, + test_plugin_ep::MockKernelLookup(lookup_kernel_func), + GraphOptimizerRegistry(nullptr, nullptr, &logger), + nullptr); + + ASSERT_EQ(compute_capabilities.size(), expected_claimed_nodes.empty() ? 0 : 1); + + if (compute_capabilities.size() == 1) { + ASSERT_EQ(compute_capabilities[0]->sub_graph->nodes.size(), expected_claimed_nodes.size()); + + for (NodeIndex node_index : compute_capabilities[0]->sub_graph->nodes) { + const Node* node = graph_viewer.GetNode(node_index); + ASSERT_NE(node, nullptr); + EXPECT_EQ(expected_claimed_nodes.count(node->Name()), 1); + } + } + } + }; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + // Build dummy kernel lookup function that always returns null. Used by OrtEp using EpGraphSupportInfo_LookUpKernel(). + // Expect that the plugin EP will not claim any nodes because no valid kernel definitions are registered. + { + auto mock_kernel_lookup_fn = [](const Node& /*node*/) -> const KernelCreateInfo* { + return nullptr; + }; + + ort_ep->GetCapability = GetCapabilityTakeSingleNode; + std::unordered_set expected_claimed_nodes; // Empty. No nodes should be claimed. + run_test(*ep, expected_claimed_nodes, mock_kernel_lookup_fn); + } + + // Test a kernel lookup function that only returns a kernel definition for a Mul node. + // Expect that plugin EP will take only the Mul node. 
+ { + KernelDefBuilder builder; + builder.SetName("Mul").SinceVersion(1).Provider("TestOrtEp"); + auto kernel_create_info = std::make_unique(builder.Build(), nullptr); + + auto mock_kernel_lookup_fn = [&kernel_create_info](const Node& node) -> const KernelCreateInfo* { + if (kernel_create_info->kernel_def->OpName() == node.OpType()) { + return kernel_create_info.get(); + } + return nullptr; + }; + + ort_ep->GetCapability = GetCapabilityTakeSingleNode; + std::unordered_set expected_claimed_nodes = {"mul_0"}; + run_test(*ep, expected_claimed_nodes, mock_kernel_lookup_fn); + } +} + +TEST(PluginExecutionProviderTest, KernelDefCxxApis) { + auto check_kernel_def = [&](const KernelDef& expected, Ort::ConstKernelDef actual) -> void { + EXPECT_EQ(expected.OpName(), actual.GetOperatorType()); + EXPECT_EQ(expected.Domain(), actual.GetDomain()); + + auto [expected_start, expected_end] = expected.SinceVersion(); + auto [actual_start, actual_end] = actual.GetSinceVersion(); + + EXPECT_EQ(expected_start, actual_start); + + if (expected_end != actual_end) { + // Instead of using INT_MAX, the public API just sets the start version equal to the end version. + EXPECT_EQ(actual_start, actual_end); + EXPECT_EQ(expected_end, std::numeric_limits::max()); + } + + EXPECT_EQ(expected.Provider(), actual.GetExecutionProvider()); + EXPECT_EQ(expected.InputMemoryType(0), actual.GetInputMemType(0)); + EXPECT_EQ(expected.InputMemoryType(1), actual.GetInputMemType(1)); + EXPECT_EQ(expected.OutputMemoryType(1), actual.GetOutputMemType(1)); + }; + + // Check that C++ APIs for Ort::KernelDef return the expected values. + { + KernelDefBuilder builder; + std::unique_ptr expected_def = builder.SetName("Mul") + .SetDomain("TestDomain") + .SinceVersion(3, 13) + .Provider("TestOrtEp") + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .OutputMemoryType(OrtMemTypeCPUOutput, 1) + .Build(); + + Ort::KernelDefBuilder api_builder; + Ort::KernelDef actual_def = api_builder.SetOperatorType("Mul") + .SetDomain("TestDomain") + .SetSinceVersion(3, 13) + .SetExecutionProvider("TestOrtEp") + .SetInputMemType(0, OrtMemTypeCPUInput) + .SetInputMemType(1, OrtMemTypeCPUInput) + .SetOutputMemType(1, OrtMemTypeCPUOutput) + .Build(); + + EXPECT_NO_FATAL_FAILURE(check_kernel_def(*expected_def, actual_def.GetConst())); + } + + // SinceVersion with no explicit end (defaults to start version) + { + KernelDefBuilder builder; + std::unique_ptr expected_def = builder.SetName("Mul") + .SetDomain("TestDomain") + .Provider("TestOrtEp") + .SinceVersion(3) // end should default to INT_MAX (means not set) + .Build(); + + Ort::KernelDefBuilder api_builder; + Ort::KernelDef actual_def = api_builder.SetOperatorType("Mul") + .SetDomain("TestDomain") + .SetExecutionProvider("TestOrtEp") + .SetSinceVersion(3, 3) // start == end (only one version supported) + .Build(); + EXPECT_NO_FATAL_FAILURE(check_kernel_def(*expected_def, actual_def.GetConst())); + } +} + } // namespace onnxruntime::test diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index aca345fccdc01..8b66009c0c72f 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -40,10 +40,6 @@ #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_options.h" #endif -#ifdef USE_ROCM -#include "core/providers/rocm/rocm_provider_factory.h" -#include "core/providers/rocm/gpu_data_transfer.h" -#endif #include 
"core/session/allocator_adapters.h" #include "core/session/environment.h" #include "core/session/IOBinding.h" @@ -77,9 +73,6 @@ namespace onnxruntime { #ifdef USE_CUDA ProviderInfo_CUDA& GetProviderInfo_CUDA(); #endif -#ifdef USE_ROCM -ProviderInfo_ROCM& GetProviderInfo_ROCM(); -#endif class FuseAdd : public OpKernel { public: @@ -217,7 +210,7 @@ static void CreateMatMulModel(std::unique_ptr& p_model, Prov if (provider_type == kCpuExecutionProvider) { node.SetExecutionProviderType(provider_type); } else { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_WEBGPU) node.SetExecutionProviderType(provider_type); #endif } @@ -307,7 +300,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, // And it can't be used for copying buffer to buffer since the target buffer is still in mapped state. OrtMemoryInfo mem_info(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)); gpu_alloc = session_object.GetAllocator(mem_info); - } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider) { + } else if (allocation_provider == kCudaExecutionProvider) { gpu_alloc = gpu_provider->CreatePreferredAllocators()[0]; } if (enable_graph_capture) { @@ -367,7 +360,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, if (is_preallocate_output_vec) { if (allocation_provider == kCpuExecutionProvider) { AllocateMLValue(cpu_alloc, expected_output_dims, &output_ml_value); - } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { + } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { AllocateMLValue(gpu_alloc, expected_output_dims, &output_ml_value); } else { ORT_THROW("Unsupported provider"); @@ -390,9 +383,9 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, // Now run ASSERT_STATUS_OK(session_object.Run(run_options, *io_binding)); - if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || + if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || (output_device && output_device->Type() == OrtDevice::GPU)) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_WEBGPU) // in this case we need to copy the tensor from cuda to cpu std::vector& outputs = io_binding->GetOutputs(); ASSERT_EQ(1u, outputs.size()); @@ -403,9 +396,6 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, #ifdef USE_CUDA st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif -#ifdef USE_ROCM - st = GetProviderInfo_ROCM().CreateGPUDataTransfer()->CopyTensor(rtensor, cpu_tensor); -#endif #ifdef USE_WEBGPU st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif @@ -415,7 +405,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y); #endif } else { - if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { + if (allocation_provider == kCudaExecutionProvider || 
allocation_provider == kWebGpuExecutionProvider) { ASSERT_STATUS_OK(gpu_provider->Sync()); } VerifyOutputs(io_binding->GetOutputs(), expected_output_dims, expected_values_mul_y); @@ -637,9 +627,6 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { InferenceSession session_object(so, GetEnvironment()); #ifdef USE_CUDA ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); -#endif -#ifdef USE_ROCM - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); #endif ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); ASSERT_STATUS_OK(session_object.Initialize()); @@ -676,7 +663,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { } } -#if (defined(USE_CUDA) && defined(ENABLE_CUDA_PROFILING)) || (defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING)) +#if (defined(USE_CUDA) && defined(ENABLE_CUDA_PROFILING)) ASSERT_TRUE(has_kernel_info); #endif } @@ -692,9 +679,6 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #ifdef USE_CUDA ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); #endif -#ifdef USE_ROCM - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); -#endif #ifdef USE_WEBGPU ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultWebGpuExecutionProvider())); #endif @@ -731,10 +715,6 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { has_api_info = has_api_info || lines[i].find("Api") != std::string::npos && lines[i].find("cudaLaunch") != std::string::npos; #endif -#ifdef USE_ROCM - has_api_info = has_api_info || lines[i].find("Api") != std::string::npos && - lines[i].find("hipLaunch") != std::string::npos; -#endif #ifdef USE_WEBGPU has_api_info = has_api_info || lines[i].find("Api") != std::string::npos; #endif @@ -742,7 +722,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { } // Note that the apple device is a paravirtual device which may not support webgpu timestamp query. So skip the check on it. 
-#if (defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING)) || (defined(USE_WEBGPU) && !defined(__APPLE__)) +#if (defined(USE_WEBGPU) && !defined(__APPLE__)) ASSERT_TRUE(has_api_info); #endif } @@ -1041,7 +1021,7 @@ static void TestBindHelper(const std::string& log_str, InferenceSession session_object{so, GetEnvironment()}; IExecutionProvider* gpu_provider{}; - if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider || bind_provider_type == kWebGpuExecutionProvider) { + if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kWebGpuExecutionProvider) { #ifdef USE_CUDA { auto provider = DefaultCudaExecutionProvider(); @@ -1049,13 +1029,6 @@ static void TestBindHelper(const std::string& log_str, ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); } #endif -#ifdef USE_ROCM - { - auto provider = DefaultRocmExecutionProvider(); - gpu_provider = provider.get(); - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); - } -#endif #ifdef USE_WEBGPU { ConfigOptions config_options{}; @@ -1176,11 +1149,9 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { ASSERT_TRUE(!st.IsOK()); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_WEBGPU) #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_ROCM -constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; #elif USE_WEBGPU constexpr const char* kGpuExecutionProvider = kWebGpuExecutionProvider; #endif @@ -1670,8 +1641,6 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) { ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider())); #elif USE_CUDA ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); -#elif USE_ROCM - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); #endif status = session_object.Load(model_file_name); @@ -1822,8 +1791,6 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) { ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider())); #elif USE_CUDA ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); -#elif USE_ROCM - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); #endif status = session_object.Load(model_file_name); diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index 37c3825101ba4..82003c5eea31f 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -16,6 +16,8 @@ using namespace ONNX_NAMESPACE; +extern std::unique_ptr ort_env; + namespace onnxruntime { namespace test { @@ -76,6 +78,103 @@ TEST_F(ShapeInferenceTest, BasicTest) { CheckShapeEquality(InputShape(node), OutputShape(node)); } +TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { + { + // Model #1 + // This model contains "Shape" and "Reshape" operators. 
+ auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + session_options.AddFreeDimensionOverrideByName("batch", 1); + session_options.AddFreeDimensionOverrideByName("width", 64); + session_options.AddFreeDimensionOverrideByName("height", 64); + + // Even though all graph optimizations are disabled, the free dimension override is still enabled by default. + // The shape of the graph's output should be correctly inferred by shape inference and data propagation. + Ort::Session session(*ort_env, model_path, session_options); + + // This graph only has one output + ORT_ENFORCE(session.GetOutputCount() == 1); + + Ort::TypeInfo type_info = session.GetOutputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector<int64_t> output_shape = tensor_info.GetShape(); + EXPECT_TRUE(output_shape.size() == 4) << "The output shape should have 4 dimensions"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value"; + EXPECT_TRUE(output_shape[2] == 64) << "The third dimension should have 64 as value"; + EXPECT_TRUE(output_shape[3] == 64) << "The fourth dimension should have 64 as value"; + } + + { + // Model #2 + // This model contains "Shape", "Reshape", "Gather" and "Unsqueeze" operators. + auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + session_options.AddFreeDimensionOverrideByName("batch", 1); + session_options.AddFreeDimensionOverrideByName("width", 64); + session_options.AddFreeDimensionOverrideByName("height", 64); + + // Even though all graph optimizations are disabled, the free dimension override is still enabled by default. + // The shape of the graph's output should be correctly inferred by shape inference and data propagation. + Ort::Session session(*ort_env, model_path, session_options); + + // This graph only has one output + ORT_ENFORCE(session.GetOutputCount() == 1); + + Ort::TypeInfo type_info = session.GetOutputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector<int64_t> output_shape = tensor_info.GetShape(); + EXPECT_TRUE(output_shape.size() == 3) << "The output shape should have 3 dimensions"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value"; + EXPECT_TRUE(output_shape[2] == 4096) << "The third dimension should have 4096 as value"; + } + + { + // Model #3 + // This model extends model #2 and appends Unsqueeze -> Unsqueeze -> Squeeze -> Squeeze -> Reshape to the end. + auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + session_options.AddFreeDimensionOverrideByName("batch", 1); + session_options.AddFreeDimensionOverrideByName("width", 64); + session_options.AddFreeDimensionOverrideByName("height", 64); + + // Even though all graph optimizations are disabled, the free dimension override is still enabled by default. + // The shape of the graph's output should be correctly inferred by shape inference and data propagation. 
+ Ort::Session session(*ort_env, model_path, session_options); + + // This graph only has one output + ORT_ENFORCE(session.GetOutputCount() == 1); + + Ort::TypeInfo type_info = session.GetOutputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector<int64_t> output_shape = tensor_info.GetShape(); + EXPECT_TRUE(output_shape.size() == 3) << "The output shape should have 3 dimensions"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value"; + EXPECT_TRUE(output_shape[2] == 4096) << "The third dimension should have 4096 as value"; + } + + { + // Model #4 + // This model contains Shape, Reshape, Squeeze, Range, ReduceSum. + // It's from the SoftmaxGrad_DefaultAxis test. + auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v4.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + + // Make sure it can load the model and run shape inference without errors. + Ort::Session session(*ort_env, model_path, session_options); + } +} + namespace { struct MyCustomKernelWithOptionalInput { MyCustomKernelWithOptionalInput(const OrtKernelInfo* info) { diff --git a/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc b/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc index ee3824a5ca2f2..74a812062875a 100644 --- a/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc +++ b/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc @@ -161,7 +161,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { // the internal NHWC operators are only included as part of contrib ops currently. as the EP requests the NHWC // version of the ONNX operator when matching a static kernel, those are required. 
-#if !defined(DISABLE_CONTRIB_OPS) && !defined(USE_ROCM) +#if !defined(DISABLE_CONTRIB_OPS) TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "transform/fusion/conv_relu_opset12.onnx"; diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 6a3f2f974b9f5..4d80cb704748c 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -2813,8 +2813,15 @@ TEST_F(GraphTest, ShapeInferenceAfterInitializerExternalization) { ASSERT_TRUE(graph.GetInitializedTensor("split_sizes", initializer_after)); ASSERT_NE(initializer_after, nullptr); // Debug: verify it was externalized + ASSERT_FALSE(utils::HasExternalDataInMemory(*initializer_after)) + << "We no longer externalize data in the Graph constructor."; + + // Now externalize explicitly to trigger the bug scenario + ASSERT_STATUS_OK(graph.ConvertInitializersIntoOrtValues()); + ASSERT_TRUE(graph.GetInitializedTensor("split_sizes", initializer_after)); + ASSERT_NE(initializer_after, nullptr); ASSERT_TRUE(utils::HasExternalDataInMemory(*initializer_after)) - << "Initializer was not externalized to in-memory external data"; + << "The initializer should externalize now"; // Mark the graph as needing resolve to force shape inference to run again graph.SetGraphResolveNeeded(); diff --git a/onnxruntime/test/ir/utils_test.cc b/onnxruntime/test/ir/utils_test.cc index e9744ccacbdd5..ae212a726cf4c 100644 --- a/onnxruntime/test/ir/utils_test.cc +++ b/onnxruntime/test/ir/utils_test.cc @@ -7,6 +7,7 @@ #include "core/graph/model.h" #include "test/test_environment.h" +#include "test/util/include/asserts.h" using ONNX_NAMESPACE::Utils::DataTypeUtils; using namespace ONNX_NAMESPACE; @@ -178,8 +179,7 @@ static void CreateNodeRemovalGraph(Model& model, bool removal_allowed, bool test if_node.AddAttribute("then_branch", then_branch); if_node.AddAttribute("else_branch", else_branch); - auto status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(graph.Resolve()); } static void CheckNodeRemovalSubgraphUpdate(const std::string& new_name, const Graph& subgraph) { diff --git a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp index bebff37ad8460..83f5b7f106d3e 100644 --- a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp @@ -4,11 +4,8 @@ // SPDX-License-Identifier: MIT // -// Currently this test only applies to KleidiAI Guard against it running in any other situation -#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) - +#include "mlas.h" #include "test_util.h" -#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO class MlasDynamicQgemmTest { private: @@ -20,11 +17,6 @@ class MlasDynamicQgemmTest { public: void Test(size_t M, size_t N, size_t K, size_t BatchSize) { - // Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops. - if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { - GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME2 but it was not detected. Skipping test."; - } - // Setup buffers for holding various data float* A = buffer_a.GetBuffer(M * K * BatchSize); @@ -167,6 +159,10 @@ class DynamicQgemmExecuteTest : public MlasTestFixture { }; static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + // Only register tests if MlasDynamicQGemmBatch() has an implementation available. 
+ if (!MlasIsDynamicQGemmAvailable()) { + return size_t{0}; + } + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); }); -#endif diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 9e69156efefa1..8446f88639436 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -56,7 +56,7 @@ void usage() { "\t-v: verbose\n" "\t-n [test_case_name]: Specifies a single test case to run.\n" "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', 'vsinpu'" - "'openvino', 'rocm', 'migraphx', 'acl', 'armnn', 'xnnpack', 'webgpu', 'nnapi', 'qnn', 'snpe' or 'coreml'. " + "'openvino', 'migraphx', 'acl', 'armnn', 'xnnpack', 'webgpu', 'nnapi', 'qnn', 'snpe' or 'coreml'. " "Default: 'cpu'.\n" "\t-p: Pause after launch, can attach debugger and continue\n" "\t-x: Use parallel executor, default (without -x): sequential executor.\n" @@ -228,7 +228,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool enable_dml = false; bool enable_acl = false; bool enable_armnn = false; - bool enable_rocm = false; bool enable_migraphx = false; bool enable_webgpu = false; bool enable_xnnpack = false; @@ -319,8 +318,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { enable_acl = true; } else if (!CompareCString(optarg, ORT_TSTR("armnn"))) { enable_armnn = true; - } else if (!CompareCString(optarg, ORT_TSTR("rocm"))) { - enable_rocm = true; } else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) { enable_migraphx = true; } else if (!CompareCString(optarg, ORT_TSTR("webgpu"))) { @@ -746,17 +743,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else fprintf(stderr, "ArmNN is not supported in this build\n"); return -1; -#endif - } - if (enable_rocm) { -#ifdef USE_ROCM - OrtROCMProviderOptions rocm_options; - rocm_options.do_copy_in_default_stream = true; - // TODO: Support arena configuration for users of test runner - sf.AppendExecutionProvider_ROCM(rocm_options); -#else - fprintf(stderr, "ROCM is not supported in this build"); - return -1; #endif } if (enable_migraphx) { diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index 333c1edf8ffab..08c7a0700030f 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -195,8 +195,6 @@ TEST(ComputeOptimizerTests, GatherND_E2E) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; const std::vector output_names{"output", "gather_output"}; @@ -300,8 +298,6 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnBatchDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -406,8 +402,6 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnBatchDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -504,8 +498,6 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnLastDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -602,8 +594,6 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnLastDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif 
USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -702,8 +692,6 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnSecondLastDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -801,8 +789,6 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -1232,8 +1218,6 @@ TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -1327,8 +1311,6 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnBatchDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -1420,8 +1402,6 @@ TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnSeqlenDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -1514,8 +1494,6 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnSeqlenDim) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -1608,8 +1586,6 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnSeqlenDim2) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -1781,8 +1757,6 @@ TEST(ComputeOptimizerTests, GatherRobertaE2E) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; @@ -3072,8 +3046,6 @@ TEST(ComputeOptimizerTests, ReshapeMlmBertE2E) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, -#elif USE_ROCM - onnxruntime::kRocmExecutionProvider, #endif }; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 9f0a2ad2de1c2..70e84733fa869 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5807,12 +5807,6 @@ TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_GpuOnly) { tester.TestNoFusionOccurs(); } -TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Rocm) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx"; - BiasSoftmaxFusionTester tester(model_uri, logger_.get(), kRocmExecutionProvider); - tester.TestFusionOccurs(1, true); -} - TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_Simple_Cuda) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_softmax_fusion_simple.onnx"; BiasSoftmaxFusionTester tester(model_uri, logger_.get()); @@ -6515,7 +6509,7 @@ TEST_F(GraphTransformationTests, MatMulScaleFusionWithScaleInput) { }); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST_F(GraphTransformationTests, IsInfReduceSum_Test) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/isinf_reducesum.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc 
b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc index adc173456a7db..be9e5bee4df5d 100644 --- a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc +++ b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc @@ -30,7 +30,7 @@ TEST(RuleBasedGraphTransformerTest, TestCompatibleProviders) { Graph& graph = model->MainGraph(); // Create rule based transformer with a dummy rewrite rule and register it with Cuda as compatible provider - InlinedHashSet compatible_provider{onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; + InlinedHashSet compatible_provider{onnxruntime::kCudaExecutionProvider}; auto dummy_rule = std::make_unique("DummyRule"); const auto* dummy_rule_ptr = dummy_rule.get(); diff --git a/onnxruntime/test/optimizer/test_optimizer_utils.cc b/onnxruntime/test/optimizer/test_optimizer_utils.cc index 40065c2fc7006..baba334d017fe 100644 --- a/onnxruntime/test/optimizer/test_optimizer_utils.cc +++ b/onnxruntime/test/optimizer/test_optimizer_utils.cc @@ -65,8 +65,6 @@ void RunModelWithData(const PathString& model_uri, const std::string session_log execution_provider = DefaultCpuExecutionProvider(); else if (provider_type == onnxruntime::kCudaExecutionProvider) execution_provider = DefaultCudaExecutionProvider(); - else if (provider_type == onnxruntime::kRocmExecutionProvider) - execution_provider = DefaultRocmExecutionProvider(); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); Status st; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 2c9377d48f0c4..c27700166e584 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -42,7 +42,7 @@ ABSL_FLAG(std::string, F, "", "[Usage]: -f \"dimension_denotation1:override_value1\" -f \"dimension_denotation2:override_value2\" ... or" " -f \"dimension_denotation1:override_value1 dimension_denotation2 : override_value2... \". Override value must > 0."); ABSL_FLAG(std::string, m, "duration", "Specifies the test mode. 
Value could be 'duration' or 'times'."); -ABSL_FLAG(std::string, e, "cpu", "Specifies the provider 'cpu','cuda','dnnl','tensorrt', 'nvtensorrtrtx', 'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'."); +ABSL_FLAG(std::string, e, "cpu", "Specifies the provider 'cpu','cuda','dnnl','tensorrt', 'nvtensorrtrtx', 'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'."); ABSL_FLAG(size_t, r, DefaultPerformanceTestConfig().run_config.repeated_times, "Specifies the repeated times if running in 'times' test mode."); ABSL_FLAG(size_t, t, DefaultPerformanceTestConfig().run_config.duration_in_seconds, "Specifies the seconds to run for 'duration' mode."); ABSL_FLAG(std::string, p, "", "Specifies the profile name to enable profiling and dump the profile data to the file."); @@ -325,8 +325,6 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a test_config.machine_config.provider_type_name = onnxruntime::kAclExecutionProvider; } else if (ep == "armnn") { test_config.machine_config.provider_type_name = onnxruntime::kArmNNExecutionProvider; - } else if (ep == "rocm") { - test_config.machine_config.provider_type_name = onnxruntime::kRocmExecutionProvider; } else if (ep == "migraphx") { test_config.machine_config.provider_type_name = onnxruntime::kMIGraphXExecutionProvider; } else if (ep == "xnnpack") { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index cb40a9beafeee..3468e2e55c7b6 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -661,16 +661,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); performance_test_config.run_config.enable_cpu_mem_arena ? 
1 : 0)); #else ORT_THROW("ArmNN is not supported in this build\n"); -#endif - } else if (provider_name_ == onnxruntime::kRocmExecutionProvider) { -#ifdef USE_ROCM - OrtROCMProviderOptions rocm_options; - rocm_options.miopen_conv_exhaustive_search = performance_test_config.run_config.cudnn_conv_algo; - rocm_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; - // TODO: Support arena configuration for users of perf test - session_options.AppendExecutionProvider_ROCM(rocm_options); -#else - ORT_THROW("ROCM is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc index 386a5656d8a01..63120143870d4 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.cc +++ b/onnxruntime/test/providers/compare_provider_test_utils.cc @@ -32,8 +32,6 @@ std::unique_ptr GetExecutionProvider(const std::string& prov execution_provider = DefaultNnapiExecutionProvider(); else if (provider_type == onnxruntime::kAclExecutionProvider) execution_provider = DefaultAclExecutionProvider(); - else if (provider_type == onnxruntime::kRocmExecutionProvider) - execution_provider = DefaultRocmExecutionProvider(); else if (provider_type == onnxruntime::kDmlExecutionProvider) execution_provider = DefaultDmlExecutionProvider(); else if (provider_type == onnxruntime::kWebGpuExecutionProvider) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index 11a3d67a3e13e..d711e050fb913 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -172,7 +172,7 @@ TEST_F(ActivationOpTest, Relu) { #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) +#if defined(USE_CUDA) || defined(USE_COREML) TEST_F(ActivationOpTest, Sigmoid_fp16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -262,7 +262,7 @@ TEST_F(ActivationOpTest, Relu_fp16) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) +#if defined(USE_CUDA) || defined(USE_DNNL) TEST_F(ActivationOpTest, Sigmoid_bfloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -299,8 +299,6 @@ TEST_F(ActivationOpTest, Sigmoid_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_DNNL execution_providers.push_back(DefaultDnnlExecutionProvider()); #endif @@ -339,8 +337,6 @@ TEST_F(ActivationOpTest, Tanh_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_DNNL execution_providers.push_back(DefaultDnnlExecutionProvider()); #endif @@ -379,14 +375,12 @@ TEST_F(ActivationOpTest, Relu_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_DNNL execution_providers.push_back(DefaultDnnlExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, 
&execution_providers); } -#endif // USE_CUDA || USE_ROCM || USE_DNNL +#endif // USE_CUDA || USE_DNNL #if defined(USE_DNNL) TEST_F(ActivationOpTest, LeakyRelu_bfloat16) { diff --git a/onnxruntime/test/providers/cpu/controlflow/if_test.cc b/onnxruntime/test/providers/cpu/controlflow/if_test.cc index 31b5618180bf7..d13371600389f 100644 --- a/onnxruntime/test/providers/cpu/controlflow/if_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/if_test.cc @@ -246,14 +246,11 @@ void RunTest(bool condition_value, excluded_providers.insert(kTensorrtExecutionProvider); } if (options.mixed_execution_providers) { - // we want the GPU (CUDA/ROCm) provider to be first, and the CPU provider second. all except the If should run on + // we want the GPU (CUDA) provider to be first, and the CPU provider second. all except the If should run on // GPU given that, which creates the scenario where we need to copy to/from CPU to execute the If node correctly. std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#endif -#ifdef USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -295,7 +292,7 @@ TEST(If, NoShapeInMainGraph_ShapeInSubgraph_False) { RunTest(false, options, false); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(If, MixedExecutionProviders) { RunOptions options{}; options.mixed_execution_providers = true; @@ -316,7 +313,7 @@ TEST(If, MixedExecutionProvidersNoShapeInSubgraph) { options.include_dim_values_in_subgraph = false; RunTest(true, options); } -#endif // defined(USE_CUDA) || defined(USE_ROCM) +#endif // defined(USE_CUDA) TEST(If, SymbolicShapeInMainGraph_NoShapeInSubgraph_True) { RunOptions options; diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 0bed6b6e9abee..10affa538dfad 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -360,8 +360,6 @@ void RunTest(int64_t max_iterations, std::vector> execution_providers; #if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -1042,8 +1040,8 @@ TEST(Loop, IterationCountAsOutput) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -#if defined(USE_CUDA) || defined(USE_ROCM) -// test that when part of the subgraph run on CUDA/ROCm it executes successfully +#if defined(USE_CUDA) +// test that when part of the subgraph run on CUDA it executes successfully TEST(Loop, MixedExecutionProviders) { RunOptions options{}; options.mixed_execution_providers = true; diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc index 6bf2fc63ab165..c7de8d4cba83d 100644 --- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc @@ -412,8 +412,6 @@ static void RunTest_v9(const std::string test_name, int64_t sequence_len, int64_ std::vector> execution_providers; #if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif 
execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -1167,8 +1165,6 @@ void UnknownDimInSubgraphOutput(bool is_v8, bool mixed_execution_providers = fal std::vector> execution_providers; #if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif defined(USE_ROCM) - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -1181,7 +1177,7 @@ void UnknownDimInSubgraphOutput(bool is_v8, bool mixed_execution_providers = fal TEST_8_AND_9(UnknownDimInSubgraphOutput); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(Scan, MixedExecutionProviders) { RunOptions options{}; options.is_v8 = false; diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index a923df2cebe30..b44aff56f1153 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -37,7 +37,7 @@ TEST(Random, RandomNormal2DDouble) { // The expected_output is generated using std lib, which is used by CPU kernel only. // So we need to exclude other EPs here. Ditto for other places. test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); } void RunRandomNormalLike3DFloat(bool infer_dtype = false) { @@ -74,7 +74,7 @@ void RunRandomNormalLike3DFloat(bool infer_dtype = false) { // TensorRT does not support manual seed overrides and there will be result mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(Random, RandomNormalLike3DDouble) { @@ -112,7 +112,7 @@ TEST(Random, RandomUniform1DFloat) { // TensorRT does not support manual seed overrides and there will be result mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } void RunRandomUniformLikeTest(bool infer_dtype = false) { @@ -146,7 +146,7 @@ void RunRandomUniformLikeTest(bool infer_dtype = false) { // TensorRT does not support seed parameter and there will be result mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(Random, RandomUniformLike2DDouble) { @@ -333,7 +333,7 @@ TEST(Random, MultinomialInvalidDtype) { test.Run(OpTester::ExpectResult::kExpectFailure, "Output type must be int32 or int64"); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) // We cannot call CUDA lib from UT, so just do some simple verification on output tensor. 
void RunRandomNormalGpuTest(const std::vector dims, const float mean, const float scale, const float seed, TensorProto_DataType dtype, bool is_random_like, bool infer_dtype) { diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index f9cbe46944d66..d3ea8552f60f4 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -380,17 +380,6 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); } -// ROCm doesn't support double -#ifndef USE_ROCM -TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) { - OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); - test.AddAttribute("equation", "iji->ji"); - test.AddInput("x", {2, 2, 2}, {1., 2., 3., 4., 1., 2., 3., 4.}); - test.AddOutput("o", {2, 2}, {1., 2., 3., 4.}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); -} -#endif - TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_int32) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "iji->ji"); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index cbb8ca43e8f06..3fb8cc3e1544f 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -40,8 +40,6 @@ void TestBinaryFloat16(const char* op_name, execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #elif USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif if (execution_providers.size() > 0) { OpTester tester(op_name, 14); @@ -56,8 +54,6 @@ void TestBinaryFloat16(const char* op_name, std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif if (enable_bf16 && execution_providers.size() > 0) { @@ -84,8 +80,6 @@ void TestUnaryFloat16(const char* op_name, execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #elif USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif if (execution_providers.size() > 0) { OpTester tester(op_name, opset); @@ -100,8 +94,6 @@ void TestUnaryFloat16(const char* op_name, std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif if (run_bf16 && execution_providers.size() > 0) { @@ -1439,7 +1431,7 @@ TEST(MathOpTest, Pow_float16_float16) { dims, {1.0f, 256.0f, 2.0f, 1.0f}, false); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) +#if defined(USE_CUDA) || defined(USE_COREML) TEST(MathOpTest, Pow_float_float16) { OpTester test("Pow", 12); std::vector dims{4}; @@ -1451,8 +1443,6 @@ TEST(MathOpTest, Pow_float_float16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_COREML execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #endif @@ -4079,19 +4069,17 @@ TEST(ModOpTest, 
Fmod_float16_mixed_sign) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(ModOpTest, Fmod_bfloat16_mixed_sign) { OpTester test("Mod", 13); test.AddAttribute("fmod", 1); - // Due to BFloat16's precision, if the result is too small, it's not easy get pass for both CUDA and ROCm. + // Due to BFloat16's precision, if the result is too small, it's not easy to pass on CUDA. test.AddInput("X", {4}, MakeBFloat16({8.0f, 5.0f, -8.0f, 8.0f})); test.AddInput("Y", {4}, MakeBFloat16({-3.4f, 8.0f, 3.4f, 5.0f})); test.AddOutput("Z", {4}, MakeBFloat16({1.2f, 5.f, -1.2f, 3.f})); std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 0e5a4dac465b1..d7d9d2994afa1 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -95,7 +95,7 @@ auto get_bias_value = [](const std::vector& bias_data, BiasType bias_type } // namespace -// Only CUDA, ROCM, CoreML and XNNPack kernels have float 16 support +// Only CUDA, CoreML and XNNPack kernels have float 16 support TEST(GemmOpTest, GemmNoTrans_f16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -195,7 +195,7 @@ TEST(GemmOpTest, GemmNoTrans_f16) { } } -// Only CUDA, ROCM and CoreML kernels have float 16 support +// Only CUDA and CoreML kernels have float 16 support TEST(GemmOpTest, GemmTransB_f16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -242,7 +242,7 @@ TEST(GemmOpTest, GemmTransB_f16) { } } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) +#if defined(USE_CUDA) || defined(USE_DNNL) TEST(GemmOpTest, GemmNoTrans_bfloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -270,13 +270,6 @@ TEST(GemmOpTest, GemmNoTrans_bfloat16) { test.Config(run_with_tunable_op); #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.emplace_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true)); - test.ConfigEps(std::move(execution_providers)) - .RunWithConfig(); - - execution_providers.clear(); - execution_providers.emplace_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/false)); #elif USE_DNNL execution_providers.emplace_back(DefaultDnnlExecutionProvider()); #endif diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index b7f2b5800560a..2e56aa6767598 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -365,7 +365,7 @@ TEST(MathOpTest, MatMulFloatType) { RunMatMulTest(7, false, true); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) || defined(USE_XNNPACK) +#if defined(USE_CUDA) || defined(USE_COREML) || defined(USE_XNNPACK) TEST(MathOpTest, MatMulFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -445,7 +445,7 @@ TEST(MathOpTest, MatMulZeroKInt32Type) { RunMatMulZeroKTest(); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) || defined(USE_XNNPACK) +#if defined(USE_CUDA) || defined(USE_COREML) || defined(USE_XNNPACK) TEST(MathOpTest, MatMul_Float16) { #ifdef USE_CUDA 
int min_cuda_architecture = 530; @@ -482,7 +482,7 @@ TEST(MathOpTest, MatMul_Float16) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) +#if defined(USE_CUDA) || defined(USE_DNNL) TEST(MathOpTest, MatMul_bfloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -506,13 +506,6 @@ TEST(MathOpTest, MatMul_bfloat16) { test.Config(run_with_tunable_op); #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.emplace_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true)); - test.ConfigEps(std::move(execution_providers)) - .RunWithConfig(); - - execution_providers.clear(); - execution_providers.emplace_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/false)); #elif USE_DNNL execution_providers.emplace_back(DefaultDnnlExecutionProvider()); #endif diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 215203b31f49c..962a055b5fcbe 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -66,7 +66,7 @@ TEST(SoftmaxOperator, webgpu_nan) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_XNNPACK) +#if defined(USE_CUDA) || defined(USE_XNNPACK) TEST(SoftmaxOperator, Simple_fp16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -95,7 +95,7 @@ TEST(SoftmaxOperator, Simple_fp16) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) +#if defined(USE_CUDA) || defined(USE_DNNL) TEST(SoftmaxOperator, Simple_bfloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -120,14 +120,12 @@ TEST(SoftmaxOperator, Simple_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_DNNL execution_providers.push_back(DefaultDnnlExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -#endif // USE_CUDA USE_ROCM USE_DNNL +#endif // USE_CUDA USE_DNNL TEST(SoftmaxOperator, LargeNumber) { // x = np.array([[0, 1, 2, 3], [10000, 10001, 10002, 10003]]).astype(np.float32) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index ca1a3104e0bed..b1642161d0bb8 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -238,7 +238,7 @@ TEST_P(ModelTest, Run) { // when cuda or openvino is enabled, set it to a larger value for resolving random MNIST test failure if (model_path.find(ORT_TSTR("_MNIST")) > 0) { - if (provider_name == "cuda" || provider_name == "openvino" || provider_name == "rocm") { + if (provider_name == "cuda" || provider_name == "openvino") { per_sample_tolerance = 2.5e-2; relative_per_sample_tolerance = 1e-2; } @@ -331,9 +331,6 @@ TEST_P(ModelTest, Run) { cuda_options.Update(options); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); - } else if (provider_name == "rocm") { - OrtROCMProviderOptions ep_options; - ortso.AppendExecutionProvider_ROCM(ep_options); } #ifdef USE_DNNL else if (provider_name == "dnnl") { @@ -545,7 +542,6 @@ static constexpr ORT_STRING_VIEW provider_name_migraphx = ORT_TSTR("migraphx"); #endif static constexpr ORT_STRING_VIEW provider_name_openvino = ORT_TSTR("openvino"); static constexpr ORT_STRING_VIEW provider_name_cuda = ORT_TSTR("cuda"); -static constexpr ORT_STRING_VIEW 
provider_name_rocm = ORT_TSTR("rocm"); static constexpr ORT_STRING_VIEW provider_name_dnnl = ORT_TSTR("dnnl"); // For any non-Android system, NNAPI will only be used for ort model converter #if defined(USE_NNAPI) && defined(__ANDROID__) @@ -588,9 +584,6 @@ ::std::vector<::std::basic_string> GetParameterStrings() { #ifdef USE_CUDA provider_names[provider_name_cuda] = {opset7, opset8, opset9, opset10, opset11, opset12, opset13, opset14, opset15, opset16, opset17, opset18}; #endif -#ifdef USE_ROCM - provider_names[provider_name_rocm] = {opset7, opset8, opset9, opset10, opset11, opset12, opset13, opset14, opset15, opset16, opset17, opset18}; -#endif #ifdef USE_DNNL provider_names[provider_name_dnnl] = {opset10}; #endif @@ -663,46 +656,29 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("operator_pow"), }; - static const ORTCHAR_T* cuda_rocm_flaky_tests[] = {ORT_TSTR("fp16_inception_v1"), - ORT_TSTR("fp16_shufflenet"), - ORT_TSTR("fp16_tiny_yolov2"), - ORT_TSTR("candy"), - ORT_TSTR("tinyyolov3"), - ORT_TSTR("mlperf_ssd_mobilenet_300"), - ORT_TSTR("mlperf_ssd_resnet34_1200"), - ORT_TSTR("tf_inception_v1"), - ORT_TSTR("faster_rcnn"), - ORT_TSTR("split_zero_size_splits"), - ORT_TSTR("convtranspose_3d"), - ORT_TSTR("fp16_test_tiny_yolov2-Candy"), - ORT_TSTR("fp16_coreml_FNS-Candy"), - ORT_TSTR("fp16_test_tiny_yolov2"), - ORT_TSTR("fp16_test_shufflenet"), - ORT_TSTR("keras2coreml_SimpleRNN_ImageNet"), - // models from model zoo. #26274: cuDNN frontend no valid engine - ORT_TSTR("YOLOv3"), - ORT_TSTR("YOLOv3-12"), - ORT_TSTR("YOLOv4"), - ORT_TSTR("SSD-MobilenetV1"), - ORT_TSTR("SSD-MobilenetV1-12")}; - - // For ROCm EP, also disable the following tests due to flakiness, - // mainly with precision issue and random memory access fault. - static const ORTCHAR_T* rocm_disabled_tests[] = {ORT_TSTR("bvlc_alexnet"), - ORT_TSTR("bvlc_reference_caffenet"), - ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), - ORT_TSTR("coreml_Resnet50_ImageNet"), - ORT_TSTR("mlperf_resnet"), - ORT_TSTR("mobilenetv2-1.0"), - ORT_TSTR("shufflenet"), - // models from model zoo - ORT_TSTR("AlexNet"), - ORT_TSTR("CaffeNet"), - ORT_TSTR("MobileNet v2-7"), - ORT_TSTR("R-CNN ILSVRC13"), - ORT_TSTR("ShuffleNet-v1"), - ORT_TSTR("version-RFB-320"), - ORT_TSTR("version-RFB-640")}; + static const ORTCHAR_T* cuda_flaky_tests[] = {ORT_TSTR("fp16_inception_v1"), + ORT_TSTR("fp16_shufflenet"), + ORT_TSTR("fp16_tiny_yolov2"), + ORT_TSTR("candy"), + ORT_TSTR("tinyyolov3"), + ORT_TSTR("mlperf_ssd_mobilenet_300"), + ORT_TSTR("mlperf_ssd_resnet34_1200"), + ORT_TSTR("tf_inception_v1"), + ORT_TSTR("faster_rcnn"), + ORT_TSTR("split_zero_size_splits"), + ORT_TSTR("convtranspose_3d"), + ORT_TSTR("fp16_test_tiny_yolov2-Candy"), + ORT_TSTR("fp16_coreml_FNS-Candy"), + ORT_TSTR("fp16_test_tiny_yolov2"), + ORT_TSTR("fp16_test_shufflenet"), + ORT_TSTR("keras2coreml_SimpleRNN_ImageNet"), + // models from model zoo. 
#26274: cuDNN frontend no valid engine + ORT_TSTR("YOLOv3"), + ORT_TSTR("YOLOv3-12"), + ORT_TSTR("YOLOv4"), + ORT_TSTR("SSD-MobilenetV1"), + ORT_TSTR("SSD-MobilenetV1-12")}; + static const ORTCHAR_T* openvino_disabled_tests[] = { ORT_TSTR("tf_mobilenet_v1_1.0_224"), ORT_TSTR("bertsquad"), @@ -827,13 +803,9 @@ ::std::vector<::std::basic_string> GetParameterStrings() { std::unordered_set> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests)); - bool provider_cuda_or_rocm = provider_name == provider_name_cuda; - if (provider_name == provider_name_rocm) { - provider_cuda_or_rocm = true; - all_disabled_tests.insert(std::begin(rocm_disabled_tests), std::end(rocm_disabled_tests)); - } - if (provider_cuda_or_rocm) { - all_disabled_tests.insert(std::begin(cuda_rocm_flaky_tests), std::end(cuda_rocm_flaky_tests)); + bool provider_cuda = provider_name == provider_name_cuda; + if (provider_cuda) { + all_disabled_tests.insert(std::begin(cuda_flaky_tests), std::end(cuda_flaky_tests)); } else if (provider_name == provider_name_dml) { all_disabled_tests.insert(std::begin(dml_disabled_tests), std::end(dml_disabled_tests)); } else if (provider_name == provider_name_dnnl) { diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index a529d572d7cca..93ca22a16bf67 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -703,8 +703,8 @@ TEST(BatchNormTest, NonSpatial_Complicated) { 8); // opset-8 } -// Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) +// Only CUDA/CoreML kernels have float 16 support +#if defined(USE_CUDA) || defined(USE_COREML) TEST(BatchNormTest, BatchNorm2d_fp16) { vector X{-0.91221f, -0.283559f, 0.937637f, 2.09818f, -0.100199f, -0.608113f, 0.444562f, -1.07505f, 0.940591f, -0.922262f, 0.0931303f, 0.69611f, 1.55187f, 0.159808f, 0.914874f, -1.24856f, -1.98928f, -0.331621f, @@ -923,7 +923,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1 - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); } @@ -953,7 +953,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) { // exclude CUDA Execution Provider due to flakiness // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); } @@ -983,7 +983,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) { // Same exclusions as the opset 14 test test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); } diff --git 
a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 0847c15ba7cc6..6d6fedb3c9812 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -6,6 +6,7 @@ #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK) || defined(USE_WEBGPU) #include "gtest/gtest.h" +#include "test/common/cuda_op_test_utils.h" #include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" diff --git a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc index 9e0516fd394ce..86ecee5be92dd 100644 --- a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc @@ -130,8 +130,8 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) || defined(USE_WEBGPU) +// Only a few EPs have float 16 support +#if defined(USE_CUDA) || defined(USE_COREML) || defined(USE_WEBGPU) TEST(InstanceNormalizationOpTest, InstanceNormBatch1_fp16) { OpTester test("InstanceNormalization"); diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index b2b7f1701107a..9be733e22f2e6 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -166,7 +166,7 @@ TEST(PoolFp16Test, MaxPool_DilationPadding_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolFp16Test, MaxPool_Dilation_2d) { @@ -223,7 +223,7 @@ TEST(PoolFp16Test, MaxPool_DilationPadding_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolFp16Test, MaxPool_Dilation_Ceil0_2d) { @@ -319,7 +319,7 @@ TEST(PoolTest, MaxPool_DilationPadding_3d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolBF16Test, AveragePool) { diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 1df640a84a64d..8d276b7300e37 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -474,7 +474,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, 
MaxPool_10_Dilation_2d) { @@ -558,7 +558,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { @@ -683,7 +683,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TYPED_TEST(PoolTest, GlobalMaxPool) { diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index c56aa3fb5feac..2ddb9d32cf196 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -881,7 +881,7 @@ TEST(ReductionOpTest, ReduceLogSumExp_float_no_reduction_keepdims) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(ReductionOpTest, ReduceLogSumExp_half) { OpTester test("ReduceLogSumExp"); test.AddAttribute("axes", std::vector{0, 2}); @@ -898,7 +898,7 @@ TEST(ReductionOpTest, ReduceLogSumExp_half) { test.AddOutput("reduced", {1, 2, 1}, FloatsToMLFloat16s({10.33174133f, 12.33174133f})); test.Run(); } -#endif // defined(USE_CUDA) || defined(USE_ROCM) +#endif // defined(USE_CUDA) TEST(ReductionOpTest, ReduceLogSumExp_int32) { OpTester test("ReduceLogSumExp"); @@ -1375,7 +1375,7 @@ TEST(ReductionOpTest, ReduceMax_double) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) +#if defined(USE_CUDA) || defined(USE_COREML) TEST(ReductionOpTest, ReduceMax_half) { OpTester test("ReduceMax"); test.AddAttribute("axes", std::vector{1, 2}); @@ -1392,7 +1392,7 @@ TEST(ReductionOpTest, ReduceMax_half) { test.AddOutput("reduced", {3, 1, 1}, FloatsToMLFloat16s({4.0f, 8.0f, 12.0f})); test.Run(); } -#endif // defined(USE_CUDA) || defined(USE_ROCM) +#endif // defined(USE_CUDA) TEST(ReductionOpTest, ReduceMax_int32) { OpTester test("ReduceMax"); @@ -2158,7 +2158,7 @@ TEST(ReductionOpTest, ReduceMin_double) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) +#if defined(USE_CUDA) || defined(USE_COREML) TEST(ReductionOpTest, ReduceMin_half) { OpTester test("ReduceMin"); test.AddAttribute("axes", std::vector{0, 2}); @@ -2175,7 +2175,7 @@ TEST(ReductionOpTest, ReduceMin_half) { test.AddOutput("reduced", {1, 2, 1}, FloatsToMLFloat16s({1.0f, 3.0f})); test.Run(); } -#endif // defined(USE_CUDA) || defined(USE_ROCM) +#endif // defined(USE_CUDA) TEST(ReductionOpTest, ReduceMin_int32) { OpTester test("ReduceMin"); @@ -2356,7 +2356,7 @@ TEST(ReductionOpTest, ReduceSum_int32) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) +#if defined(USE_CUDA) || defined(USE_COREML) TEST(ReductionOpTest, ReduceSumHalfHalf) { OpTester test("ReduceSum"); test.AddAttribute("keepdims", (int64_t)0); @@ -2448,7 +2448,7 @@ TEST(ReductionOpTest, ReduceSum_half_bert) { // Add more UTs for half as needed #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) 
+#if defined(USE_CUDA) || defined(USE_DNNL) TEST(ReductionOpTest, ReduceSum_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { @@ -2465,19 +2465,15 @@ TEST(ReductionOpTest, ReduceSum_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_DNNL execution_providers.push_back(DefaultDnnlExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -#endif // USE_CUDA USE_ROCM USE_DNNL +#endif // USE_CUDA USE_DNNL // on CUDA - this UT, with axes {0,2}, will go thru cudnn lib only if ATenOp is not initialized -// on ROCM - miopen call succeeded, but results in data error, thus follow the same logic done in cudnn for now -// 4.2 doesn't run properly (data error), thus enable the UT only above 4.3 -#if defined(USE_CUDA) || (defined(USE_ROCM) && ROCM_VERSION >= 40300) +#if defined(USE_CUDA) TEST(ReductionOpTest, ReduceSumBFloat16_2) { OpTester test("ReduceSum", 14); test.AddAttribute("keepdims", (int64_t)0); @@ -2488,8 +2484,6 @@ TEST(ReductionOpTest, ReduceSumBFloat16_2) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -2595,7 +2589,7 @@ TEST(ReductionOpTest, ReduceSum_batch_by_seq_by_128) { } } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(ReductionOpTest, ReduceSum_batch_by_seq_by_30528) { test_apex_reduce_sum(4 * 128, 30528); test_apex_reduce_sum(4 * 512, 30528); @@ -3783,7 +3777,7 @@ TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1b) { // test that PrepareForReduce handles this case. Called by all reduction ops so any op can be used in the test TEST(ReductionOpTest, ReduceDimWithZero1) { // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr || DefaultRocmExecutionProvider().get() != nullptr) { + if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{1,0,1}] did not match run output shape [{1,1,1}] for reduced"; } @@ -3834,7 +3828,7 @@ TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero2) { TEST(ReductionOpTest, ReduceDimWithZero2) { // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr || DefaultRocmExecutionProvider().get() != nullptr) { + if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Can't reduce on dim with value of 0 if 'keepdims' is false. Invalid output shape would be produced. 
input_shape:{?,0,?}"; } @@ -6046,7 +6040,6 @@ void test_empty_set(const std::string& op, int opset, bool axes_as_input, float kMIGraphXExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider, - kRocmExecutionProvider, kTensorrtExecutionProvider, kWebGpuExecutionProvider, }); diff --git a/onnxruntime/test/providers/cpu/tensor/expand_test.cc b/onnxruntime/test/providers/cpu/tensor/expand_test.cc index 38e3bc3af6d6b..1680f21d781b7 100644 --- a/onnxruntime/test/providers/cpu/tensor/expand_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/expand_test.cc @@ -5,7 +5,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) #include "test/providers/kernel_compute_test_utils.h" #endif @@ -201,12 +201,10 @@ TEST(ExpandOpTest, Expand_scalar_int32) { test.Run(); } -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) TEST(ExpandOpTest, Strided) { #ifdef USE_CUDA const char* provider = kCudaExecutionProvider; -#else // USE_ROCM - const char* provider = kRocmExecutionProvider; #endif // Generate contiguous output. { diff --git a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc index 81e51375b9992..23b4424b1453a 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc @@ -9,7 +9,7 @@ #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) #include "test/providers/kernel_compute_test_utils.h" #endif @@ -216,7 +216,7 @@ void RunTestWrapper() { test8.Run(); } -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) template void RunKernelComputeTest(std::initializer_list input_dims, std::initializer_list indices_dims, std::initializer_list indices_strides = {}, bool has_axis = false, @@ -228,8 +228,6 @@ void RunKernelComputeTest(std::initializer_list input_dims, std::initia GetData(input_dims, indices_dims, indices_strides, new_axis, input_data, indices_data, output_data); #ifdef USE_CUDA const char* provider = kCudaExecutionProvider; -#else // USE_ROCM - const char* provider = kRocmExecutionProvider; #endif KernelComputeTester test("GatherElements", provider); if (has_axis) test.AddAttribute("axis", axis); @@ -391,7 +389,7 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) { // skip QNN because it doesn't support out of bounds indices // skip WebGPU because it doesn't support out of bounds indices test.Run(OpTester::ExpectResult::kExpectFailure, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider, kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider}); } @@ -413,7 +411,7 @@ TEST(GatherElementsOpTest, BigIndices) { test1.Run(); } -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) TEST(GatherElementsOpTest, Strided_float) { RunKernelComputeTestWrapper(); } 
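For readers skimming the strided-tensor changes above: once the USE_ROCM branch is deleted, the provider selection in these kernel-compute tests collapses to a single guarded constant. A minimal sketch of the resulting pattern, using only names that appear in the tests above:

```cpp
// Sketch of the provider selection left behind after the ROCm branch is removed.
// kCudaExecutionProvider is the only GPU provider the strided tests still target.
#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA)
const char* provider = kCudaExecutionProvider;
KernelComputeTester test("GatherElements", provider);  // as in RunKernelComputeTest above
#endif
```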
TEST(GatherElementsOpTest, Strided_double) { RunKernelComputeTestWrapper(); } diff --git a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc index be79a6d29d539..997ff2869592c 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc @@ -107,7 +107,7 @@ TEST(GatherOpTest, Gather_invalid_index_cpu) { .RunWithConfig(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(GatherOpTest, Gather_invalid_index_gpu) { OpTester test("Gather"); // Invalid index 3. data[3] does not exist. @@ -126,8 +126,6 @@ TEST(GatherOpTest, Gather_invalid_index_gpu) { test #if defined(USE_CUDA) .ConfigEp(DefaultCudaExecutionProvider()) -#else - .ConfigEp(DefaultRocmExecutionProvider()) #endif .RunWithConfig(); } @@ -440,9 +438,6 @@ TEST(ShrunkenGatherOpTest, ShrunkenGather_PositiveAxis) { #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); #endif -#ifdef USE_ROCM - execution_providers.emplace_back(DefaultRocmExecutionProvider()); -#endif OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); test.AddAttribute("axis", 0LL); @@ -464,9 +459,6 @@ TEST(ShrunkenGatherOpTest, ShrunkenGather_NegativeAxis) { #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); #endif -#ifdef USE_ROCM - execution_providers.emplace_back(DefaultRocmExecutionProvider()); -#endif OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); test.AddAttribute("axis", -1LL); @@ -488,9 +480,6 @@ TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidIndicesRank) { #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); #endif -#ifdef USE_ROCM - execution_providers.emplace_back(DefaultRocmExecutionProvider()); -#endif OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); test.AddAttribute("axis", 0LL); @@ -512,9 +501,6 @@ TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidInputRank) { #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); #endif -#ifdef USE_ROCM - execution_providers.emplace_back(DefaultRocmExecutionProvider()); -#endif OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); test.AddAttribute("axis", 0LL); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 18eec7d1b42a3..c0325c07bab5e 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -362,8 +362,8 @@ TEST(DequantizeLinearOpTest, Per_Channel_Axis_1_int32) { 0, 4, 16, 48, 0, 20, 80, 240}); // Disable Tensorrt EP due to error, only activation types allowed as input to this layer. - // Disable CUDA, ROCm EP, there is no implementation for int32_t. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kCudaExecutionProvider, kRocmExecutionProvider}); + // Disable CUDA EP, there is no implementation for int32_t. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kCudaExecutionProvider}); } // 1d zero & scale with uint8 broadcast axis -2 (-2 resolves to axis 0) diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index f3b0695bdbd9c..c47bdd31f8458 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -109,9 +109,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA | WEBGPU: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider, kWebGpuExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kWebGpuExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_uint8) { @@ -140,9 +139,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_int8) { @@ -198,11 +196,10 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support - // ROCm: results mismatch // DML: results mismatch test.Run( OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -283,9 +280,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA | WEBGPU: result mismatch due to not implementing NHWC support - // ROCm: results mismatch // TRT: Segmentation fault in A100 - std::unordered_set excluded_providers({kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kWebGpuExecutionProvider}); + std::unordered_set excluded_providers({kCudaExecutionProvider, kCudaNHWCExecutionProvider, kWebGpuExecutionProvider}); test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers)); } @@ -310,9 +306,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); test.SetOutputAbsErr("Y", 1.0f); // CUDA: result mismatch due to not implementing NHWC support - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, 
kCudaNHWCExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { @@ -550,9 +545,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); }; run_test(false); @@ -650,10 +644,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); test.SetOutputAbsErr("Y", 1.0f); // CUDA: result mismatch due to not implementing NHWC support - // ROCm: results mismatch // DML: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -763,9 +756,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y, false, .0f, 1.0f); // CUDA: result mismatch due to not implementing NHWC support - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); }; run_test(false); @@ -2239,12 +2231,12 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear) { 36.590908f, 76.59091f, 116.59091f}; // Nchw is not supported by CUDA Resize implementation - InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + InlinedVector excluded_eps = {kCudaExecutionProvider}; TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y, excluded_eps); } TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { - InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + InlinedVector excluded_eps = {kCudaExecutionProvider}; { std::vector X(16); std::iota(X.begin(), X.end(), uint8_t(0)); @@ -2391,7 +2383,7 @@ TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) { 46.606194f, 19.878183f, 43.87818f, 21.358122f, 45.35812f, 22.907503f, 46.907505f, 24.387442f, 48.387444f}; - InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + InlinedVector excluded_eps = {kCudaExecutionProvider}; TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y, excluded_eps); } @@ -2487,7 +2479,7 @@ TEST(ResizeOpTest, NoAntialias_AlignCorners_Cubic_Floor_NHWC) { 23.0000f, 24.0000f, }; // clang-format on - InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + InlinedVector excluded_eps = {kCudaExecutionProvider}; TestAntialiasing( {{"antialias", "0"}, {"coordinate_transformation_mode", "align_corners"}, @@ -2519,7 +2511,7 @@ TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { 187.08333f, 195.91667f, 198.41667f, 205.91667f, 208.41667f, 217.25f, 219.75f, 227.25f, 229.75f, 238.58333f, 241.08333f, 248.58333f, 251.08333f}; - InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + InlinedVector excluded_eps = {kCudaExecutionProvider}; 
TestAntialiasing( {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y, excluded_eps); diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 479a515403c74..56856211c39b3 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -268,7 +268,7 @@ static void scatter_invalid_index(const char* op_name, int op_version) { test.AddOutput("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f}); test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}); } TEST(Scatter, InvalidIndex) { diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index 688b2cd39c8fb..3e50b23353cb8 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -152,7 +152,7 @@ void RunTestWrapper() { RunTest({}, {}); #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_WEBGPU) // _TileMemcpyKernelFromInput, vectorized 4 RunTest({256, 512}, {3, 1}); @@ -263,7 +263,7 @@ TEST(TensorOpTest, TileStringType) { RunTestWrapper(); } TEST(TensorOpTest, TileBoolType) { RunTestWrapperForBool(); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) +#if defined(USE_CUDA) || defined(USE_WEBGPU) TEST(TensorOpTest, TileMLFloat16Type) { RunTestWrapper(); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 73a5bce768a2a..00449cd442a32 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -771,11 +771,9 @@ TEST(TransposeOpTest, DoTransposeEltWise) { #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_ROCM -constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; #endif -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) static void TestTranspose( const std::vector& perm, const std::vector& x_dims, @@ -867,7 +865,7 @@ TEST(TransposeOpTest, TransposeBigMLFloat16) { // Exercises CanUse_cublasTransp const std::vector Y_dims{1, 1449, 1449, 3}; TestTransposeMLFloat16(perm, X_dims, Y_dims); } -#endif // defined(USE_CUDA) || defined(USE_ROCM) +#endif // defined(USE_CUDA) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc index 3ac8053aef95e..10dd14be4ce92 100644 --- a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc @@ -91,9 +91,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearestTest) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - 
{kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearestTest_int32) { @@ -174,10 +173,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearestTest_int32) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support - // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearestTest_uint8) { @@ -259,9 +256,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearestTest_uint8) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearest2XTest) { @@ -335,9 +331,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearest2XTest) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearest222XTest) { @@ -441,9 +436,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearest222XTest) { test.AddOutput("Y", {(int64_t)(N * scales[0]), (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearest15XTest) { @@ -513,9 +507,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearest15XTest) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearestTest_NoScale) { @@ -615,9 +608,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearest2XTest_int32) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOp4DBilinearTest) { @@ -691,9 +683,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4D1CBilinearTest) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not 
implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) { @@ -765,9 +756,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOp2DBilinearTest) { @@ -885,9 +875,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest_int32) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearestTest_1D) { @@ -985,9 +974,8 @@ TEST(UpsampleOpTest, NhwcUpsampleOpNearest2XTest_opset9) { test.AddOutput("Y", {N, (int64_t)(H * scales[1]), (int64_t)(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch - // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider}); } } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc index 70c7a5b2bcdcb..5deef01cd783e 100644 --- a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -22,10 +22,17 @@ namespace test { // --------- Helpers --------- +// CUDA errors are sticky and may affect subsequent API calls. +// We want to clear the error when the support check fails.
+void ClearCudaError() { + ORT_IGNORE_RETURN_VALUE(::cudaGetLastError()); +} + static bool IsCudaMemPoolSupported() { int ort_cuda_rt_version = 0; cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -36,6 +43,7 @@ static bool IsCudaMemPoolSupported() { int ort_cuda_driver_version = 0; cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -65,9 +73,10 @@ static bool IsCudaMemPoolSupported() { cudaMemPool_t pool; auto cuda_error = cudaMemPoolCreate(&pool, &props); if (cuda_error != cudaSuccess) { + ClearCudaError(); return false; } - cuda_error = cudaMemPoolDestroy(pool); + ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool)); return true; } @@ -80,7 +89,9 @@ static ::cudaStream_t NewCudaStream() { } static void DestroyCudaStream(::cudaStream_t s) { - if (s) (void)::cudaStreamDestroy(s); + if (s) { + EXPECT_EQ(cudaSuccess, ::cudaStreamDestroy(s)); + } } static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) { diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc index 93e688570631e..9f75797936f03 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.cc +++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc @@ -32,15 +32,6 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { ASSERT_STATUS_OK(execution_providers.Add(ep_type, std::move(cuda_ep))); } #endif -#ifdef USE_ROCM - if (provider_ == kRocmExecutionProvider) { - auto rocm_ep = DefaultRocmExecutionProvider(); - ep_type = rocm_ep->Type(); - auto rocm_transfer = rocm_ep->GetDataTransfer(); - ASSERT_STATUS_OK(dtm.RegisterDataTransfer(std::move(rocm_transfer))); - ASSERT_STATUS_OK(execution_providers.Add(ep_type, std::move(rocm_ep))); - } -#endif const auto& logger = DefaultLoggingManager().DefaultLogger(); Model model("test", false, ModelMetaData(), ORT_TSTR(""), IOnnxRuntimeOpSchemaRegistryList(), @@ -56,8 +47,8 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { if (provider_ == kCpuExecutionProvider || data.is_cpu_data_) { initializer_map[name] = data.value_; } -#if defined(USE_CUDA) || defined(USE_ROCM) - if ((provider_ == kCudaExecutionProvider || provider_ == kRocmExecutionProvider) && !data.is_cpu_data_) { +#if defined(USE_CUDA) + if (provider_ == kCudaExecutionProvider && !data.is_cpu_data_) { const Tensor& tensor = data.value_.Get(); Tensor gpu_tensor(tensor.DataType(), tensor.Shape(), diff --git a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc index ca3b9ee8c5a9b..9f85d37f1e2f4 100644 --- a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc +++ b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc @@ -188,6 +188,7 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) { ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true); } +#if defined(WIN32) static bool SessionHasEp(Ort::Session& session, const char* ep_name) { // Access the underlying InferenceSession. const OrtSession* ort_session = session; @@ -203,7 +204,6 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) { return has_ep; } -#if defined(WIN32) // Tests autoEP feature to automatically select an EP that supports the GPU. // Currently only works on Windows. 
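The new ClearCudaError helper in cuda_mempool_arena_test.cc above exists because a failed capability probe would otherwise leave an error code that later CUDA calls report. A minimal sketch of that probe-and-clear pattern, assuming only the CUDA runtime API (the ProbeCudaRuntime name is illustrative, not from the patch):

```cpp
#include <cuda_runtime_api.h>

// Probe a CUDA capability; on failure, read-and-reset the error state so the
// expected failure is not surfaced by a later, unrelated CUDA call.
static bool ProbeCudaRuntime() {
  int runtime_version = 0;
  if (cudaRuntimeGetVersion(&runtime_version) != cudaSuccess) {
    (void)cudaGetLastError();  // cudaGetLastError() returns and clears the last error
    return false;
  }
  return true;
}
```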
TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index d8cc56d738175..af9706855ee3c 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -203,6 +203,48 @@ TEST_P(TypeTests, IOTypes) { } } +TEST(NvExecutionProviderTest, TestSessionOutputs) { + /* + * Model #1: + * + * "input" ---> TopK --- + |---> "scores" + |--- Less ---> "Less_output_0" + |--- Div ---> "Div_output_0" + |--- Mod ---> "labels" + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 4); + } + + /* + * Model #2: + * + * "X" ---> Dropout ---> MatMul ---> "Y" + * ^ | + * | | + * "W" ------ ----> Can't be graph's output + * + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 1); + } +} + INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, diff --git a/onnxruntime/test/providers/qnn/README.md b/onnxruntime/test/providers/qnn/README.md new file mode 100644 index 0000000000000..c3d0c720a1aa4 --- /dev/null +++ b/onnxruntime/test/providers/qnn/README.md @@ -0,0 +1,70 @@ +# ONNX Runtime QNN Execution Provider Tests +## Overview +1. The `onnxruntime/test/providers/qnn` directory contains integration tests for the Qualcomm Neural Network (QNN) execution provider. +2. Most test cases run an ONNX model through the QNN EP and then verify the inference results against those from the CPU EP. + +## Building the Tests +The tests are built as part of the regular ONNX Runtime build. After a successful build, you will have an executable named: +- onnxruntime_provider_test.exe (Windows) +- onnxruntime_provider_test (Linux/macOS) + +## Running the Tests +1. QNN supports several backends. You can use the standard Google Test syntax for filtering: + - `onnxruntime_provider_test.exe --gtest_filter=QnnCPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnHTPBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnGPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnIRBackendTests.*` +2. Saving Test Artifacts + - For debugging, it is often helpful to keep the intermediate files that the tests generate. The following environment + variables are recognized by the test binary: + - `QNN_DUMP_ONNX`: Saves the input ONNX model used for the test + - `QNN_DUMP_JSON`: Saves the JSON QNN graph by setting the provider option `dump_json_qnn_graph` + - `QNN_DUMP_DLC`: Saves the compiled QNN DLC file by setting `dump_qnn_ir_dlc` and pointing the provider option `qnn_ir_backend_path` to `QnnIr.dll` + - The artifacts will be saved to a directory named `<test_suite_name>_<test_name>` + ``` + .
+ ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy + │ ├── dumped_f16_model.onnx # float16 ONNX model + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy + ├── dumped_f32_model.onnx # float32 ONNX model + ├── dumped_qdq_model.onnx # QDQ ONNX model + ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + └── QNNExecutionProvider_QNN_XXXX_X_X.json + + # All artifact files are placed under the current working directory from which the test binary is invoked. + ``` +3. Verbose + - `QNN_VERBOSE`: Sets the ONNX Runtime log level to `ORT_LOGGING_LEVEL_VERBOSE` + +4. You can enable any combination of these environment variables, for example: + - On Linux/macOS + ```bash + export QNN_DUMP_ONNX=1 + export QNN_DUMP_JSON=1 + export QNN_DUMP_DLC=1 + export QNN_VERBOSE=1 + ``` + - On Windows + ```cmd + set QNN_DUMP_ONNX=1 + set QNN_DUMP_JSON=1 + set QNN_DUMP_DLC=1 + set QNN_VERBOSE=1 + ``` + ```ps1 + $Env:QNN_DUMP_ONNX = "1" + $Env:QNN_DUMP_JSON = "1" + $Env:QNN_DUMP_DLC = "1" + $Env:QNN_VERBOSE = "1" + ``` + +## Notes +- A failure in the QNN backend can prevent the test artifacts from being saved. +- The `onnxruntime_provider_test.exe` does not automatically delete the artifact directories, so you may want to prune them after a debugging session. diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 1c70f4012090e..15a9132aaa16c 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -101,6 +101,12 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_json() || + QNNTestEnvironment::GetInstance().dump_dlc()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; @@ -110,6 +116,10 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -123,7 +133,27 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov // Serialize the model to a string.
std::string model_data; model.ToProto().SerializeToString(&model_data); + + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); + } + TryEnableQNNSaver(provider_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + provider_options["dump_qnn_ir_dlc"] = "1"; + provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = output_dir.string(); + } RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params, @@ -134,11 +164,21 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, logging::Severity log_severity, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -152,7 +192,27 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO // Serialize the model to a string. 
std::string model_data; model.ToProto().SerializeToString(&model_data); + + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); + } + TryEnableQNNSaver(provider_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + provider_options["dump_qnn_ir_dlc"] = "1"; + provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = output_dir.string(); + } SessionOptions so; so.session_logid = "QNN_EP_TestLogID"; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index aeb3a9a114871..4d4f795d161b1 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -499,6 +499,77 @@ struct QDQTolerance { float value; }; +class QNNTestEnvironment { + public: + // Delete copy constructor and assignment operator + QNNTestEnvironment(const QNNTestEnvironment&) = delete; + QNNTestEnvironment& operator=(const QNNTestEnvironment&) = delete; + + // Static method to get the singleton instance + static QNNTestEnvironment& GetInstance() { + static QNNTestEnvironment instance; + return instance; + } + + bool dump_onnx() const { return dump_onnx_; } + bool dump_json() const { return dump_json_; } + bool dump_dlc() const { return dump_dlc_; } + bool verbose() const { return verbose_; } + + std::filesystem::path CreateTestcaseDirs() { + std::string test_suite_name = ::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name(); + std::string test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); + std::filesystem::path output_dir = std::filesystem::current_path() / (test_suite_name + "_" + test_name); + std::filesystem::create_directories(output_dir); + + return output_dir; + } + + private: + // Private constructor for singleton + QNNTestEnvironment() { + ParseEnvironmentVars(); + } + + // Helper function to check if an environment variable is set + bool IsEnvVarSet(const char* name) { + const char* value = std::getenv(name); + if (value == nullptr) { + return false; + } + + // Consider the variable set if it's not empty and not "0" + return *value != '\0' && *value != '0'; + } + + void ParseEnvironmentVars() { + if (IsEnvVarSet("QNN_DUMP_ONNX")) { + std::cout << "[QNN only] ONNX model dumping enabled via environment variable." << std::endl; + dump_onnx_ = true; + } + + if (IsEnvVarSet("QNN_DUMP_JSON")) { + std::cout << "[QNN only] Json QNN Graph dumping enabled via environment variable." << std::endl; + dump_json_ = true; + } + + if (IsEnvVarSet("QNN_DUMP_DLC")) { + std::cout << "[QNN only] DLC dumping enabled via environment variable." << std::endl; + dump_dlc_ = true; + } + + if (IsEnvVarSet("QNN_VERBOSE")) { + std::cout << "Verbose enabled via environment variable." 
<< std::endl; + verbose_ = true; + } + } + + bool dump_onnx_ = false; + bool dump_json_ = false; + bool dump_dlc_ = false; + bool verbose_ = false; +}; + /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: * @@ -529,15 +600,21 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}, std::function* qnn_ep_graph_checker = nullptr) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); - - // Uncomment to dump LOGGER() output to stdout. - // logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -551,8 +628,11 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); - // Uncomment to save f32 model to disk for debugging. - // ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, ToPathString("cmp_accuracy.f32.onnx"))); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); + } // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; @@ -594,11 +674,27 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve()); qdq_model.ToProto().SerializeToString(&qdq_model_data); - // Uncomment to save QDQ model to disk for debugging. 
- // ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, ToPathString("cmp_accuracy.qdq.onnx"))); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_qdq_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx QDQ model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, dump_path)); + } bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + qnn_options["dump_qnn_ir_dlc"] = "1"; + qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + qnn_options["dump_json_qnn_graph"] = "1"; + qnn_options["json_qnn_graph_dir"] = output_dir.string(); + } std::vector qnn_qdq_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; @@ -743,11 +839,21 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, logging::Severity log_severity = logging::Severity::kERROR, const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -760,6 +866,12 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); + } + // Run f32 model on CPU EP and collect outputs. 
std::vector cpu_f32_outputs; InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All, @@ -796,8 +908,27 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f16_model.MainGraph().Resolve()); f16_model.ToProto().SerializeToString(&f16_model_data); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f16_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float16 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f16_model, dump_path)); + } + bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + qnn_options["dump_qnn_ir_dlc"] = "1"; + qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + qnn_options["dump_json_qnn_graph"] = "1"; + qnn_options["json_qnn_graph_dir"] = output_dir.string(); + } std::vector qnn_f16_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 6a6545c68cb4f..dce0d570ec238 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -1,5 +1,6 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "onnxruntime_cxx_api.h" #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" @@ -18,6 +19,8 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; +extern std::unique_ptr ort_env; + namespace onnxruntime { namespace test { @@ -1360,5 +1363,49 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); } + +TEST(TensorrtExecutionProviderTest, TestSessionOutputs) { + /* + * Model #1: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + { + OrtTensorRTProviderOptionsV2 provider_options; + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_TensorRT_V2(provider_options); + + auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 4); + } + + /* + * Model #2: + * + * "X" ---> Dropout ---> MatMul ---> "Y" + * ^ | + * | | + * "W" ------ ----> Can't be graph's output + * + */ + { + OrtTensorRTProviderOptionsV2 provider_options; + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_TensorRT_V2(provider_options); + + auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 1); + } +} } // namespace test } // namespace onnxruntime diff --git 
a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py index bb65533c3d1e0..f16415d625b5d 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -139,11 +139,6 @@ def common_test_model_gemm( providers = ["CPUExecutionProvider"] if "CUDAExecutionProvider" in available_providers: providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] - elif "ROCMExecutionProvider" in available_providers: - providers = [ - ("ROCMExecutionProvider", {"tunable_op_enable": "1", "tunable_op_tuning_enable": "1"}), - ("CPUExecutionProvider", {}), - ] expected = (a.T if kwargs.get("transA", 0) else a) @ (b.T if kwargs.get("transB", 0) else b) expected *= kwargs.get("alpha", 1.0) @@ -341,29 +336,6 @@ def test_combinations(self, shapeA, shapeB, transA, transB): self.assertEqual(expected.dtype, got[0].dtype) assert_allclose(expected, got[0]) - @parameterized.parameterized.expand( - [ - ("FLOAT8E4M3FN", "FLOAT16", 0, 0), - ("FLOAT16", "FLOAT8E4M3FN", 0, 0), - ("FLOAT16", "FLOAT8E4M3FN", 0, 1), - ] - ) - @unittest.skipIf("ROCMExecutionProvider" not in available_providers, reason="Not running without ROCm.") - @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") - def test_model_rocm_gemm_float8_e4m3(self, a_float_name, b_float_name, transA, transB): - self.common_test_model_gemm( - a_float_name=a_float_name, - b_float_name=b_float_name, - c_float_name="FLOAT8E4M3FN", - rtol=0.5, - dtype=TensorProto.FLOAT16, - transA=0, - transB=transB, - scaleY=False, - alpha=10.0, - beta=0.0, - ) - if __name__ == "__main__": # TestFloat8Gemm8().test_model_gemm_float() diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 7f003453add89..768a97d7ed2bc 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -542,47 +542,6 @@ def run_advanced_test(cuda_lib): print("run advanced_test") run_advanced_test(cuda) - if "ROCMExecutionProvider" in onnxrt.get_available_providers(): - - def run_rocm_options_test(): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["ROCMExecutionProvider"]) - self.assertIn("ROCMExecutionProvider", sess.get_providers()) - options = sess.get_provider_options() - - def test_get_and_set_option_with_values(option_name, option_values): - provider_options = sess.get_provider_options() - self.assertIn("ROCMExecutionProvider", provider_options) - rocm_options = options["ROCMExecutionProvider"] - self.assertIn(option_name, rocm_options) - for option_value in option_values: - rocm_options[option_name] = option_value - sess.set_providers(["ROCMExecutionProvider"], [rocm_options]) - new_provider_options = sess.get_provider_options() - self.assertEqual( - new_provider_options.get("ROCMExecutionProvider", {}).get(option_name), - str(option_value), - ) - - test_get_and_set_option_with_values("tunable_op_enable", ["1", "0"]) - - test_get_and_set_option_with_values("tunable_op_tuning_enable", ["1", "0"]) - - test_get_and_set_option_with_values("tunable_op_max_tuning_duration_ms", ["-1", "1"]) - - test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"]) - - # test for user_compute_stream - option = options["ROCMExecutionProvider"] - option["user_compute_stream"] = "1" - sess.set_providers(["ROCMExecutionProvider"], [option]) - new_options = sess.get_provider_options() - new_option = 
new_options["ROCMExecutionProvider"] - self.assertEqual(new_option["user_compute_stream"], "1") - # set user_compute_stream will set has_user_compute_stream to 1 too - self.assertEqual(new_option["has_user_compute_stream"], "1") - - run_rocm_options_test() - def test_invalid_set_providers(self): with self.assertRaises(RuntimeError) as context: sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) @@ -675,9 +634,6 @@ def do_test_get_and_set_tuning_results(ep): if "CUDAExecutionProvider" in onnxrt.get_available_providers(): do_test_get_and_set_tuning_results("CUDAExecutionProvider") - if "ROCMExecutionProvider" in onnxrt.get_available_providers(): - do_test_get_and_set_tuning_results("ROCMExecutionProvider") - def test_run_model_with_optional_sequence_input(self): sess = onnxrt.InferenceSession(get_name("identity_opt.onnx")) x = [np.array([1, 2, 3, 4, 5]).astype(np.float32)] @@ -1799,11 +1755,6 @@ def check_failure(providers, provider_options): check_failure([("a", {1: 2})], [{3: 4}]) def test_register_custom_e_ps_library(self): - available_eps = C.get_available_providers() - # skip amd gpu build - if "ROCMExecutionProvider" in available_eps: - return - if sys.platform.startswith("win"): shared_library = os.path.abspath("test_execution_provider.dll") diff --git a/onnxruntime/test/python/quantization/test_quant_preprocess.py b/onnxruntime/test/python/quantization/test_quant_preprocess.py new file mode 100644 index 0000000000000..c93f081072f35 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quant_preprocess.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import onnx + +from onnxruntime.quantization.shape_inference import quant_pre_process + + +class TestUpsample(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory(prefix="ort.quant_preprocess_") + self.temp_path = Path(self.temp_dir.name) + + def tearDown(self): + self.temp_dir.cleanup() + + def build_upsample_model(self, input_shape=(1, 3, 32, 32)): + """ + Build a model with deprecated Upsample op (opset <= 10) for testing version conversion. 
+ """ + input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, input_shape) + output_shape = (input_shape[0], input_shape[1], input_shape[2] * 2, input_shape[3] * 2) + output_tensor = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, output_shape) + + # Create scales for upsample + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + scales_initializer = onnx.numpy_helper.from_array(scales, "scales") + + upsample_node = onnx.helper.make_node( + "Upsample", + ["input", "scales"], + ["output"], + name="upsample_node", + mode="nearest", + ) + + graph = onnx.helper.make_graph( + [upsample_node], + "upsample_graph", + [input_tensor], + [output_tensor], + initializer=[scales_initializer], + ) + # Use opset 10 to trigger Upsample -> Resize conversion + opset_imports = [onnx.helper.make_opsetid("", 10)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return model + + def test_upsample_to_resize_conversion(self): + """ + Test that deprecated Upsample ops are converted to Resize ops. + """ + model = self.build_upsample_model() + input_path = self.temp_path / "input_model.onnx" + output_path = self.temp_path / "preprocessed_model.onnx" + + onnx.save_model(model, input_path) + + # Verify original model has Upsample op + self.assertEqual(model.graph.node[0].op_type, "Upsample") + self.assertEqual(model.opset_import[0].version, 10) + + quant_pre_process( + input_model=str(input_path), + output_model_path=str(output_path), + skip_optimization=True, + skip_onnx_shape=True, + skip_symbolic_shape=True, + ) + + self.assertTrue(output_path.exists()) + preprocessed_model = onnx.load(str(output_path)) + + # Verify Upsample was converted to Resize and opset was upgraded + node_types = [node.op_type for node in preprocessed_model.graph.node] + assert "Resize" in node_types + assert "Upsample" not in node_types + assert preprocessed_model.opset_import[0].version >= 11 + + +class TestClip(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory(prefix="ort.quant_preprocess_") + self.temp_path = Path(self.temp_dir.name) + + def tearDown(self): + self.temp_dir.cleanup() + + def build_clip_model(self, input_shape=(1, 3, 32, 32)): + """ + Build a model with Clip op using ai.onnx v6 for testing version conversion. + """ + input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, input_shape) + output_tensor = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, input_shape) + + # Create min and max values for clip + min_val = np.array(0.0, dtype=np.float32) + max_val = np.array(6.0, dtype=np.float32) + min_initializer = onnx.numpy_helper.from_array(min_val, "min") + max_initializer = onnx.numpy_helper.from_array(max_val, "max") + + clip_node = onnx.helper.make_node( + "Clip", + ["input", "min", "max"], + ["output"], + name="clip_node", + ) + + graph = onnx.helper.make_graph( + [clip_node], + "clip_graph", + [input_tensor], + [output_tensor], + initializer=[min_initializer, max_initializer], + ) + # Use opset 6 to trigger version conversion + opset_imports = [onnx.helper.make_opsetid("", 6)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return model + + def test_clip_version_conversion(self): + """ + Test that Clip op from ai.onnx v6 is upgraded to v11 after quant_pre_process. 
+ """ + model = self.build_clip_model() + input_path = self.temp_path / "input_clip_model.onnx" + output_path = self.temp_path / "preprocessed_clip_model.onnx" + + onnx.save_model(model, input_path) + + # Verify original model has Clip op with opset 6 + self.assertEqual(model.graph.node[0].op_type, "Clip") + self.assertEqual(model.opset_import[0].version, 6) + + quant_pre_process( + input_model=str(input_path), + output_model_path=str(output_path), + skip_optimization=True, + skip_onnx_shape=True, + skip_symbolic_shape=True, + ) + + self.assertTrue(output_path.exists()) + preprocessed_model = onnx.load(str(output_path)) + + # Verify Clip op is still present and opset was upgraded to v11 or higher + node_types = [node.op_type for node in preprocessed_model.graph.node] + assert "Clip" in node_types + assert preprocessed_model.opset_import[0].version >= 11 + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/parity_utilities.py b/onnxruntime/test/python/transformers/parity_utilities.py index 7066d5a5425cb..fa16f0e67a523 100644 --- a/onnxruntime/test/python/transformers/parity_utilities.py +++ b/onnxruntime/test/python/transformers/parity_utilities.py @@ -181,8 +181,6 @@ def create_ort_session(onnx_model_path, use_gpu=True, optimized=True, verbose=Fa if not optimized: execution_providers.append("MIGraphXExecutionProvider") - execution_providers.append("ROCMExecutionProvider") - execution_providers.append("CPUExecutionProvider") return InferenceSession(onnx_model_path, sess_options, providers=execution_providers) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index c4a92e959b273..a39b9c89e5898 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -47,10 +47,6 @@ #include #endif -#ifdef USE_ROCM -#include -#endif - #ifdef USE_DML #include #include @@ -294,14 +290,6 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod session_options.AppendExecutionProvider_Dnnl(dnnl_options); #else return; -#endif - } else if (provider_type == 3) { -#ifdef USE_ROCM - std::cout << "Running simple inference with rocm provider" << std::endl; - OrtROCMProviderOptions rocm_options; - session_options.AppendExecutionProvider_ROCM(rocm_options); -#else - return; #endif } else { std::cout << "Running simple inference with default provider" << std::endl; @@ -347,7 +335,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx"); -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_DML) static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx"); #endif static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx"); @@ -1715,15 +1703,12 @@ TEST(CApiTest, test_custom_op_library) { #ifdef USE_CUDA TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 1, nullptr, lib_name.c_str()); -#elif USE_ROCM - TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, - expected_values_y, 3, nullptr, lib_name.c_str()); #elif USE_DML TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 4, nullptr, lib_name.c_str()); #else -TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, - 
expected_values_y, 0, nullptr, lib_name.c_str()); + TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str()); #endif } @@ -2103,27 +2088,6 @@ TEST(CApiTest, get_allocator_cuda) { } #endif -#ifdef USE_ROCM -TEST(CApiTest, get_allocator_rocm) { - Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); - Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); - - Ort::MemoryInfo info_rocm("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); - Ort::Allocator rocm_allocator(session, info_rocm); - - auto allocator_info = rocm_allocator.GetInfo(); - ASSERT_TRUE(info_rocm == allocator_info); - void* p = rocm_allocator.Alloc(1024); - ASSERT_NE(p, nullptr); - rocm_allocator.Free(p); - - auto mem_allocation = rocm_allocator.GetAllocation(1024); - ASSERT_NE(nullptr, mem_allocation.get()); - ASSERT_EQ(1024U, mem_allocation.size()); -} -#endif - #if defined(USE_QNN) TEST(CApiTest, get_allocator_qnn_htp_shared) { @@ -2424,7 +2388,7 @@ TEST(CApiTest, io_binding_qnn_htp_shared) { #endif // defined(USE_QNN) -#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_DML) TEST(CApiTest, basic_cuda_graph) { [[maybe_unused]] const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; @@ -2445,19 +2409,6 @@ TEST(CApiTest, basic_cuda_graph) { cuda_options.Update(options_map); session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); -#elif defined(USE_ROCM) - // Enable hip graph in rocm provider option. - OrtROCMProviderOptions* rocm_options = nullptr; - ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); - std::unique_ptr - rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); - std::vector keys{"enable_hip_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); - - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( - static_cast(session_options), - rel_rocm_options.get()) == nullptr); #elif defined(USE_DML) // Enable dynamic DML graph in DML provider option. 
session_options.AddConfigEntry("ep.dml.enable_graph_capture", "1"); @@ -2469,13 +2420,7 @@ TEST(CApiTest, basic_cuda_graph) { #endif Ort::Session session(*ort_env, MODEL_URI, session_options); -#if defined(USE_ROCM) -// local hipify -#define cudaMemcpy hipMemcpy -#define cudaMemcpyHostToDevice hipMemcpyHostToDevice -#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost - Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); -#elif defined(USE_CUDA) || defined(USE_TENSORRT) +#if defined(USE_CUDA) || defined(USE_TENSORRT) Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); #elif defined(USE_DML) Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); @@ -2491,7 +2436,7 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_NE(input_data.get(), nullptr); -#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_TENSORRT) (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); #elif defined(USE_DML) ComPtr input_resource; @@ -2524,7 +2469,7 @@ TEST(CApiTest, basic_cuda_graph) { // Check the values against the bound raw memory (needs copying from device to host first) std::array y_values; -#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_TENSORRT) (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); #elif defined(USE_DML) ComPtr output_resource; @@ -2538,7 +2483,7 @@ TEST(CApiTest, basic_cuda_graph) { // Replay the captured CUDA graph session.Run(Ort::RunOptions(), binding); -#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_TENSORRT) (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); #elif defined(USE_DML) DownloadDataFromDml(dml_objects, output_resource.Get(), gsl::make_span(output_cpu_bytes, sizeof(float) * y_values.size())); @@ -2549,7 +2494,7 @@ TEST(CApiTest, basic_cuda_graph) { // Change the input and replay the CUDA graph again. 
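For reference, the capture-then-replay flow this C++ test covers is also reachable from the Python API. A rough CUDA-only sketch, not part of this change: the input/output names, the enable_cuda_graph provider option, and the model path are assumptions based on the test above, and the snippet presumes a CUDA build.

import numpy as np
import onnxruntime as ort

providers = [("CUDAExecutionProvider", {"enable_cuda_graph": "1"}), "CPUExecutionProvider"]
sess = ort.InferenceSession("testdata/mul_1.onnx", providers=providers)

x_dev = ort.OrtValue.ortvalue_from_numpy(
    np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), "cuda", 0)
y_dev = ort.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, "cuda", 0)

binding = sess.io_binding()
binding.bind_ortvalue_input("X", x_dev)      # names assumed to match mul_1.onnx
binding.bind_ortvalue_output("Y", y_dev)

sess.run_with_iobinding(binding)             # first run captures the CUDA graph
x_dev.update_inplace(np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]], dtype=np.float32))
sess.run_with_iobinding(binding)             # later runs replay it with the updated input
print(y_dev.numpy())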
x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; -#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_TENSORRT) (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); #elif defined(USE_DML) UploadDataToDml(dml_objects, input_resource.Get(), gsl::make_span(reinterpret_cast(x_values.data()), sizeof(float) * x_values.size())); @@ -2559,7 +2504,7 @@ TEST(CApiTest, basic_cuda_graph) { session.Run(Ort::RunOptions(), binding); -#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_TENSORRT) (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); #elif defined(USE_DML) DownloadDataFromDml(dml_objects, output_resource.Get(), gsl::make_span(output_cpu_bytes, sizeof(float) * y_values.size())); @@ -2571,14 +2516,9 @@ TEST(CApiTest, basic_cuda_graph) { // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); -#if defined(USE_ROCM) -#undef cudaMemcpy -#undef cudaMemcpyHostToDevice -#undef cudaMemcpyDeviceToHost -#endif } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_DML) struct CudaGraphInputOutputData_0 { const std::array x_shape = {3, 2}; std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -2622,12 +2562,6 @@ static void RunWithCudaGraphAnnotation(T& cg_data, Ort::MemoryAllocation& input_data, Ort::MemoryAllocation& output_data, const char* cuda_graph_annotation) { -// a local hipify of select cuda symbols to avoid code duplication -#ifdef USE_ROCM -#define cudaMemcpy hipMemcpy -#define cudaMemcpyHostToDevice hipMemcpyHostToDevice -#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost -#endif #ifdef USE_DML Ort::SessionOptions session_options; Ort::Allocator allocator(session, info_mem); @@ -2731,11 +2665,6 @@ static void RunWithCudaGraphAnnotation(T& cg_data, // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); -#ifdef USE_ROCM -#undef cudaMemcpy -#undef cudaMemcpyHostToDevice -#undef cudaMemcpyDeviceToHost -#endif } TEST(CApiTest, basic_cuda_graph_with_annotation) { @@ -2758,20 +2687,6 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); -#elif defined(USE_ROCM) - // Enable hip graph in rocm provider option. - OrtROCMProviderOptions* rocm_options = nullptr; - ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); - std::unique_ptr - rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); - std::vector keys{"enable_hip_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); - - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( - static_cast(session_options), - rel_rocm_options.get()) == nullptr); - Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); #endif Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options); @@ -2820,34 +2735,10 @@ TEST(CApiTest, cuda_graph_with_shape_nodes) { } #endif // defined(USE_CUDA) || defined(USE_TENSORRT) -#if defined(USE_ROCM) -TEST(CApiTest, hip_graph_with_shape_nodes) { - const auto& api = Ort::GetApi(); - - // Enable hip graph in rocm provider option. 
- OrtROCMProviderOptions* rocm_options = nullptr; - ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); - std::unique_ptr - rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); - std::vector keys{"enable_hip_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); - - Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( - static_cast(session_options), - rel_rocm_options.get()) == nullptr); - - // Successful loading of the ONNX model with shape nodes with hip graph feature enabled - Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); -} -#endif // defined(USE_ROCM) - #if defined(USE_DML) TEST(CApiTest, dml_graph_with_shape_nodes) { const auto& api = Ort::GetApi(); - // Enable hip graph in rocm provider option. const OrtDmlApi* ort_dml_api; Ort::SessionOptions session_options; session_options.AddConfigEntry("ep.dml.enable_graph_capture", "1"); @@ -2861,7 +2752,7 @@ TEST(CApiTest, dml_graph_with_shape_nodes) { #endif // REDUCED_OPS_BUILD -#endif // defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, create_tensor) { const char* s[] = {"abc", "kmp"}; diff --git a/onnxruntime/test/testdata/node_output_not_used.onnx b/onnxruntime/test/testdata/node_output_not_used.onnx new file mode 100644 index 0000000000000..e2726182fddc2 Binary files /dev/null and b/onnxruntime/test/testdata/node_output_not_used.onnx differ diff --git a/onnxruntime/test/testdata/node_output_not_used.py b/onnxruntime/test/testdata/node_output_not_used.py new file mode 100644 index 0000000000000..d36d5e9cfd2f8 --- /dev/null +++ b/onnxruntime/test/testdata/node_output_not_used.py @@ -0,0 +1,43 @@ +import onnx +from onnx import TensorProto, helper + + +def create_model_with_node_output_not_used(model_path): + # Create graph + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) + w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + # Dropout node (two outputs) + dropout_node = helper.make_node( + "Dropout", + inputs=["X"], + outputs=["dropout_out", "dropout_mask"], + name="DropoutNode", + ) + + # MatMul node + matmul_node = helper.make_node( + "MatMul", + inputs=["dropout_out", "W"], + outputs=["Y"], + name="MatMulNode", + ) + + graph = helper.make_graph( + nodes=[dropout_node, matmul_node], + name="DropoutMatMulGraph", + inputs=[x, w], + outputs=[y], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)]) + + onnx.checker.check_model(model) + onnx.save(model, model_path) + + print(f"Model saved to: {model_path}") + + +if __name__ == "__main__": + create_model_with_node_output_not_used("node_output_not_used.onnx") diff --git a/onnxruntime/test/testdata/squeeze_mul_relu.onnx b/onnxruntime/test/testdata/squeeze_mul_relu.onnx new file mode 100644 index 0000000000000..5b18047e67d6a Binary files /dev/null and b/onnxruntime/test/testdata/squeeze_mul_relu.onnx differ diff --git a/onnxruntime/test/testdata/squeeze_mul_relu.py b/onnxruntime/test/testdata/squeeze_mul_relu.py new file mode 100644 index 0000000000000..9999383fb02de --- /dev/null +++ b/onnxruntime/test/testdata/squeeze_mul_relu.py @@ -0,0 +1,50 @@ +from onnx import TensorProto, checker, helper, save, shape_inference + +# A --> Squeeze --> Mul --> Relu --> 
Mul(2x) --> C +# ^ +# | +# B ----------------+ +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "Squeeze", + inputs=["A"], + outputs=["squeeze0_output"], + name="squeeze_0", + ), + helper.make_node( + "Mul", + inputs=["squeeze0_output", "B"], + outputs=["mul0_output"], + name="mul_0", + ), + helper.make_node( + "Relu", + inputs=["mul0_output"], + outputs=["relu0_output"], + name="relu_0", + ), + helper.make_node( + "Mul", + inputs=["relu0_output", "Const2"], + outputs=["C"], + name="mul_1", + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 1, 2]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + outputs=[ + helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 2]), + ], + initializer=[ + helper.make_tensor("Const2", TensorProto.FLOAT, [3, 2], [2.0] * 6), + ], +) + +model = helper.make_model(graph_proto) +model = shape_inference.infer_shapes(model) +checker.check_model(model, True) +save(model, "squeeze_mul_relu.onnx") diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx new file mode 100644 index 0000000000000..e18aa31e414e5 Binary files /dev/null and b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx differ diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py new file mode 100644 index 0000000000000..6537a3cd357c3 --- /dev/null +++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py @@ -0,0 +1,53 @@ +import onnx +from onnx import TensorProto, helper + +# 1. Define graph input with symbolic shape ['batch', 3, 'width', 'height'] +input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, "width", "height"]) + +# 2. Define intermediate and output tensors +shape_out = helper.make_tensor_value_info("shape_out", TensorProto.INT64, [4]) # Shape output +reshape_a_out = helper.make_tensor_value_info("reshape_a_out", TensorProto.FLOAT, None) +output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, None) + +# 3. Create the initializer for Reshape A's 'shape' input: [0, 32, -1] +shape_initializer = helper.make_tensor( + name="reshape_a_shape", + data_type=TensorProto.INT64, + dims=[3], + vals=[0, 32, -1], +) + +# 4. Create nodes: +# Shape node +shape_node = helper.make_node("Shape", inputs=["input"], outputs=["shape_out"], name="ShapeNode") + +# Reshape A node: takes input + constant shape +reshape_a_node = helper.make_node( + "Reshape", inputs=["input", "reshape_a_shape"], outputs=["reshape_a_out"], name="ReshapeA" +) + +# Reshape B node: takes Shape + ReshapeA outputs, outputs final output +reshape_b_node = helper.make_node("Reshape", inputs=["reshape_a_out", "shape_out"], outputs=["output"], name="ReshapeB") + +# 5. Assemble the graph +graph = helper.make_graph( + nodes=[shape_node, reshape_a_node, reshape_b_node], + name="Shape_Reshape_Model", + inputs=[input_tensor], + outputs=[output_tensor], + initializer=[shape_initializer], + value_info=[shape_out, reshape_a_out], +) + +# 6. Define the model (set IR and opset) +model = helper.make_model( + graph, + opset_imports=[helper.make_operatorsetid("", 18)], + producer_name="onnx-example-generator", +) +model.ir_version = onnx.IR_VERSION + +# 7. 
Save the model +onnx.save(model, "test_shape_data_propagation_with_shape_related_nodes.onnx") + +print("Model saved to test_shape_data_propagation_with_shape_related_nodes.onnx") diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx new file mode 100644 index 0000000000000..ff41075ff64cc Binary files /dev/null and b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx differ diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py new file mode 100644 index 0000000000000..7cfbcca8d4d03 --- /dev/null +++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py @@ -0,0 +1,59 @@ +import onnx +from onnx import TensorProto, helper + +# === Graph input/output === +input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, "width", "height"]) +output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", 3, "width*height"]) + +# === Initializers === +B = helper.make_tensor("B", TensorProto.FLOAT, [], [1.0]) + +# Gather indices +g0_idx = helper.make_tensor("g0_idx", TensorProto.INT64, [], [0]) +g1_idx = helper.make_tensor("g1_idx", TensorProto.INT64, [], [1]) +g2_idx = helper.make_tensor("g2_idx", TensorProto.INT64, [], [2]) +g3_idx = helper.make_tensor("g3_idx", TensorProto.INT64, [], [3]) + +# Unsqueeze axes tensors +axes_unsq0 = helper.make_tensor("axes_unsq0", TensorProto.INT64, [1], [0]) +axes_unsq1 = helper.make_tensor("axes_unsq1", TensorProto.INT64, [1], [0]) +axes_unsq2 = helper.make_tensor("axes_unsq2", TensorProto.INT64, [1], [0]) + +# === Nodes === +div = helper.make_node("Div", ["input", "B"], ["div_out"]) + +# Two Shape nodes from Div +shape_left = helper.make_node("Shape", ["div_out"], ["shape_left_out"]) +shape_right = helper.make_node("Shape", ["div_out"], ["shape_right_out"]) + +# Left Shape path +gather0 = helper.make_node("Gather", ["shape_left_out", "g0_idx"], ["g0_out"]) +gather1 = helper.make_node("Gather", ["shape_left_out", "g1_idx"], ["g1_out"]) +unsq0 = helper.make_node("Unsqueeze", ["g0_out", "axes_unsq0"], ["u0_out"]) +unsq1 = helper.make_node("Unsqueeze", ["g1_out", "axes_unsq1"], ["u1_out"]) + +# Right Shape path +gather2 = helper.make_node("Gather", ["shape_right_out", "g2_idx"], ["g2_out"]) +gather3 = helper.make_node("Gather", ["shape_right_out", "g3_idx"], ["g3_out"]) +mul = helper.make_node("Mul", ["g2_out", "g3_out"], ["mul_out"]) +unsq2 = helper.make_node("Unsqueeze", ["mul_out", "axes_unsq2"], ["u2_out"]) + +# Combine +concat = helper.make_node("Concat", ["u0_out", "u1_out", "u2_out"], ["concat_out"], axis=0) +reshape = helper.make_node("Reshape", ["div_out", "concat_out"], ["output"]) + +# === Graph === +graph = helper.make_graph( + [div, shape_left, shape_right, gather0, gather1, gather2, gather3, mul, unsq0, unsq1, unsq2, concat, reshape], + "Div_Shape_Gather_Concat_Reshape", + [input_tensor], + [output_tensor], + initializer=[B, g0_idx, g1_idx, g2_idx, g3_idx, axes_unsq0, axes_unsq1, axes_unsq2], +) + +# === Model === +model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)], producer_name="onnx-example") +onnx.checker.check_model(model) +onnx.save(model, "test_shape_data_propagation_with_shape_related_nodes_v2.onnx") + +print("✅ Model saved as 
test_shape_data_propagation_with_shape_related_nodes_v2.onnx") diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx new file mode 100644 index 0000000000000..2889ec34afd41 Binary files /dev/null and b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx differ diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.py new file mode 100644 index 0000000000000..75bbbe7b4557c --- /dev/null +++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.py @@ -0,0 +1,109 @@ +import onnx +from onnx import TensorProto, helper + +# === Graph input/output === +input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, "width", "height"]) +output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", 3, "width*height"]) + +# === Initializers === +B = helper.make_tensor("B", TensorProto.FLOAT, [], [1.0]) + +# Gather indices +g0_idx = helper.make_tensor("g0_idx", TensorProto.INT64, [], [0]) +g1_idx = helper.make_tensor("g1_idx", TensorProto.INT64, [], [1]) +g2_idx = helper.make_tensor("g2_idx", TensorProto.INT64, [], [2]) +g3_idx = helper.make_tensor("g3_idx", TensorProto.INT64, [], [3]) + +# Unsqueeze axes tensors +axes_unsq0 = helper.make_tensor("axes_unsq0", TensorProto.INT64, [1], [0]) +axes_unsq1 = helper.make_tensor("axes_unsq1", TensorProto.INT64, [1], [0]) +axes_unsq2 = helper.make_tensor("axes_unsq2", TensorProto.INT64, [1], [0]) + +# === Nodes === +div = helper.make_node("Div", ["input", "B"], ["div_out"]) + +# Two Shape nodes from Div +shape_left = helper.make_node("Shape", ["div_out"], ["shape_left_out"]) +shape_right = helper.make_node("Shape", ["div_out"], ["shape_right_out"]) + +# Left Shape path +gather0 = helper.make_node("Gather", ["shape_left_out", "g0_idx"], ["g0_out"]) +gather1 = helper.make_node("Gather", ["shape_left_out", "g1_idx"], ["g1_out"]) +unsq0 = helper.make_node("Unsqueeze", ["g0_out", "axes_unsq0"], ["u0_out"]) +unsq1 = helper.make_node("Unsqueeze", ["g1_out", "axes_unsq1"], ["u1_out"]) + +# Right Shape path +gather2 = helper.make_node("Gather", ["shape_right_out", "g2_idx"], ["g2_out"]) +gather3 = helper.make_node("Gather", ["shape_right_out", "g3_idx"], ["g3_out"]) +mul = helper.make_node("Mul", ["g2_out", "g3_out"], ["mul_out"]) +unsq2 = helper.make_node("Unsqueeze", ["mul_out", "axes_unsq2"], ["u2_out"]) + +# Combine +concat = helper.make_node("Concat", ["u0_out", "u1_out", "u2_out"], ["concat_out"], axis=0) + +# Axes initializers +axes_u1 = helper.make_tensor("axes_u1", TensorProto.INT64, [1], [1]) +axes_u2 = helper.make_tensor("axes_u2", TensorProto.INT64, [1], [1]) +axes_s1 = helper.make_tensor("axes_s1", TensorProto.INT64, [1], [1]) +axes_s2 = helper.make_tensor("axes_s2", TensorProto.INT64, [1], [1]) + +# First Unsqueeze +unsqueeze1 = helper.make_node("Unsqueeze", inputs=["concat_out", "axes_u1"], outputs=["u1"], name="Unsqueeze_1") + +# Second Unsqueeze +unsqueeze2 = helper.make_node("Unsqueeze", inputs=["u1", "axes_u2"], outputs=["u2"], name="Unsqueeze_2") + +# First Squeeze +squeeze1 = helper.make_node("Squeeze", inputs=["u2", "axes_s1"], outputs=["s1"], name="Squeeze_1") + +# Second Squeeze +squeeze2 = helper.make_node("Squeeze", inputs=["s1", "axes_s2"], outputs=["squeeze_output"], 
name="Squeeze_2") + +reshape = helper.make_node("Reshape", ["div_out", "squeeze_output"], ["output"]) + +# === Graph === +graph = helper.make_graph( + [ + div, + shape_left, + shape_right, + gather0, + gather1, + gather2, + gather3, + mul, + unsq0, + unsq1, + unsq2, + concat, + unsqueeze1, + unsqueeze2, + squeeze1, + squeeze2, + reshape, + ], + "Div_Shape_Gather_Concat_Reshape", + [input_tensor], + [output_tensor], + initializer=[ + B, + g0_idx, + g1_idx, + g2_idx, + g3_idx, + axes_unsq0, + axes_unsq1, + axes_unsq2, + axes_u1, + axes_u2, + axes_s1, + axes_s2, + ], +) + +# === Model === +model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)], producer_name="onnx-example") +onnx.checker.check_model(model) +onnx.save(model, "test_shape_data_propagation_with_shape_related_nodes_v3.onnx") + +print("✅ Model saved as test_shape_data_propagation_with_shape_related_nodes_v3.onnx") diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v4.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v4.onnx new file mode 100644 index 0000000000000..d13f317b3a1c8 Binary files /dev/null and b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v4.onnx differ diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx new file mode 100644 index 0000000000000..340c3d420d574 Binary files /dev/null and b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx differ diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py new file mode 100644 index 0000000000000..232abb2ed9163 --- /dev/null +++ b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py @@ -0,0 +1,78 @@ +import onnx +from onnx import TensorProto, helper + + +def create_model_with_topk_graph_output(model_path): + # ====================== + # ---- Inputs ---- + # ====================== + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["N"]) + + # ====================== + # ---- Initializers ---- + # ====================== + k = helper.make_tensor("K", TensorProto.INT64, dims=[1], vals=[300]) + zero = helper.make_tensor("zero", TensorProto.INT64, dims=[], vals=[0]) + twenty_six = helper.make_tensor("twenty_six", TensorProto.INT64, dims=[], vals=[26]) + + # ====================== + # ---- Nodes ---- + # ====================== + topk_node = helper.make_node( + "TopK", + inputs=["input", "K"], + outputs=["scores", "topk_indices"], + name="TopK", + ) + + less_node = helper.make_node( + "Less", + inputs=["topk_indices", "zero"], + outputs=["Less_output_0"], + name="Less", + ) + + div_node = helper.make_node( + "Div", + inputs=["topk_indices", "twenty_six"], + outputs=["Div_17_output_0"], + name="Div", + ) + + mod_node = helper.make_node( + "Mod", + inputs=["topk_indices", "twenty_six"], + outputs=["labels"], + name="Mod", + ) + + # ========================= + # ---- Graph Outputs ---- + # ========================= + scores_out = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["K"]) + less_out = helper.make_tensor_value_info("Less_output_0", TensorProto.BOOL, ["K"]) + div_out = helper.make_tensor_value_info("Div_17_output_0", TensorProto.INT64, ["K"]) + labels_out = helper.make_tensor_value_info("labels", TensorProto.INT64, ["K"]) + + # ====================== + # ---- Graph ---- + # ====================== + graph = helper.make_graph( 
+ nodes=[topk_node, less_node, div_node, mod_node], + name="TopKGraph", + inputs=[input_tensor], + outputs=[scores_out, less_out, div_out, labels_out], + initializer=[k, zero, twenty_six], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)]) + + # Validate + Save + onnx.checker.check_model(model) + onnx.save(model, model_path) + + print(f"Model saved to: {model_path}") + + +if __name__ == "__main__": + create_model_with_topk_graph_output("topk_and_multiple_graph_outputs.onnx") diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index 4d640e0f5e33d..0887c200e49f0 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -666,7 +666,6 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, kArmNNExecutionProvider, kNnapiExecutionProvider, kVSINPUExecutionProvider, - kRocmExecutionProvider, kCoreMLExecutionProvider, kCoreMLExecutionProviderMLProgram, kQnnExecutionProvider, @@ -732,8 +731,6 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultAclExecutionProvider(); else if (provider_type == onnxruntime::kArmNNExecutionProvider) execution_provider = DefaultArmNNExecutionProvider(); - else if (provider_type == onnxruntime::kRocmExecutionProvider) - execution_provider = DefaultRocmExecutionProvider(); else if (provider_type == onnxruntime::kCoreMLExecutionProvider) execution_provider = DefaultCoreMLExecutionProvider(); else if (provider_type == kCoreMLExecutionProviderMLProgram) @@ -771,27 +768,6 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, number_of_pre_packed_weights_counter, number_of_shared_pre_packed_weights_counter); - // Run Models with subscribed run_options->config_options - if (ctx_.run_options != nullptr && - ctx_.run_options->config_options.GetConfigEntry(kOpTesterRunOptionsConfigTestTunableOp) == "true") { - std::vector> execution_providers; - if (provider_type == onnxruntime::kRocmExecutionProvider) { - execution_providers.emplace_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true)); - } - - if (!execution_providers.empty()) { - ExecuteModelForEps( - std::move(execution_providers), model, ctx_.session_options, - ctx_.expect_result, ctx_.expected_failure_string, - ctx_.run_options, feeds, output_names, - &custom_session_registries_, - /*assign_ep_for_nodes=*/true, - allow_released_onnx_opset_only, - number_of_pre_packed_weights_counter, - number_of_shared_pre_packed_weights_counter); - } - } - has_run = true; cur_provider = "not set"; } diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index cea3feeb927af..edccc314b75f0 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -329,9 +329,5 @@ std::unique_ptr DefaultDmlExecutionProvider() { return nullptr; } -std::unique_ptr DefaultRocmExecutionProvider(bool) { - return nullptr; -} - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index ab3136c0b7b33..fb7b168f5e158 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -23,7 +23,6 @@ std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); std::shared_ptr 
CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); -std::shared_ptr CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params); std::shared_ptr CreateExecutionProviderFactory_Cann(const OrtCANNProviderOptions* provider_options); @@ -57,7 +56,6 @@ std::unique_ptr DefaultVSINPUExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); std::unique_ptr DefaultAclExecutionProvider(bool enable_fast_math = false); std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); -std::unique_ptr DefaultRocmExecutionProvider(bool test_tunable_op = false); std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlprogram = false); std::unique_ptr DefaultSnpeExecutionProvider(); std::unique_ptr DefaultQnnExecutionProvider(); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index a2b30eac03514..24703fd92dac5 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -130,7 +130,7 @@ std::vector> GeneratePreTrainingTransformers( } transformers.emplace_back(std::make_unique(compatible_eps, level, true)); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) transformers.emplace_back(std::make_unique(compatible_eps, true /* skip_device_check*/)); #else @@ -145,7 +145,7 @@ std::vector> GeneratePreTrainingTransformers( // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps)); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) // We are supposed to use the execution provider as an indicator, // but here we don't have access to the registered EP at this point // as the session is not initialized yet. So using macro for now. @@ -180,8 +180,7 @@ std::vector> GeneratePreTrainingTransformers( config.number_recompute_layers, compatible_eps)); } if (config.propagate_cast_ops_config.level >= 0) { - const InlinedHashSet cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; + const InlinedHashSet cuda_execution_provider = {onnxruntime::kCudaExecutionProvider}; transformers.emplace_back(std::make_unique(config.propagate_cast_ops_config.strategy, static_cast(config.propagate_cast_ops_config.level), config.propagate_cast_ops_config.allow, @@ -194,8 +193,8 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps, config.print_input_density)); -#if defined(USE_CUDA) || defined(USE_ROCM) - // Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel. +#if defined(USE_CUDA) + // Put this under CUDA guard as it depends on PadAndUnflatten CUDA kernel. // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard. 
transformers.emplace_back(std::make_unique(compatible_eps, config.print_input_density)); @@ -261,17 +260,16 @@ InlinedVector> GenerateTransformers( switch (level) { case TransformerLevel::Level1: { InlinedHashSet l1_execution_providers = {}; - InlinedHashSet cuda_rocm_execution_providers = {onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; + InlinedHashSet cuda_execution_providers = {onnxruntime::kCudaExecutionProvider}; // TODO hack - constant folding currently doesn't work after mixed precision transformation so it's disabled for now // ORT uses CPU kernels to evaluate constant values but some of them don't support fp16 // transformers.emplace_back(std::make_unique(l1_execution_providers)); transformers.emplace_back(std::make_unique(l1_execution_providers)); transformers.emplace_back(std::make_unique(free_dimension_overrides)); - transformers.emplace_back(std::make_unique(cuda_rocm_execution_providers)); - transformers.emplace_back(std::make_unique(cuda_rocm_execution_providers)); - transformers.emplace_back(std::make_unique(cuda_rocm_execution_providers)); + transformers.emplace_back(std::make_unique(cuda_execution_providers)); + transformers.emplace_back(std::make_unique(cuda_execution_providers)); + transformers.emplace_back(std::make_unique(cuda_execution_providers)); transformers.emplace_back(std::make_unique(l1_execution_providers)); InlinedHashSet excluded_initializers(weights_to_train.begin(), weights_to_train.end()); transformers.emplace_back(std::make_unique(l1_execution_providers, excluded_initializers)); diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index c4c7a98ba116a..772c1ef5d856a 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -26,11 +26,6 @@ namespace onnxruntime { std::unique_ptr CreateCUDAPinnedAllocator(const char* name); } // namespace onnxruntime #endif -#ifdef USE_ROCM -namespace onnxruntime { -std::unique_ptr CreateROCMPinnedAllocator(const char* name); -} // namespace onnxruntime -#endif using namespace onnxruntime; using namespace onnxruntime::common; @@ -638,22 +633,6 @@ void setup_training_params(BertParameters& params) { } #endif -#ifdef USE_ROCM - { - OrtROCMProviderOptions info; - info.device_id = gsl::narrow(MPIContext::GetInstance().GetLocalRank()); - info.do_copy_in_default_stream = true; - - if (params.gpu_mem_limit_in_gb > 0) { - info.gpu_mem_limit = gsl::narrow(params.gpu_mem_limit_in_gb * 1024 * 1024 * 1024); - } - info.miopen_conv_exhaustive_search = true; // true, exhaustive search (slow) - - params.providers.emplace(kRocmExecutionProvider, RocmProviderFactoryCreator::Create(&info)); - params.input_allocator = CreateROCMPinnedAllocator(HIP_PINNED); - } -#endif - params.loss_func_info = LossFunctionInfo(OpDef("BertLoss", kOnnxDomain), "total_loss", {/*prediction_masked_lm*/ "output1", diff --git a/orttraining/orttraining/models/gpt2/main.cc b/orttraining/orttraining/models/gpt2/main.cc index 165d69fb1378a..b1e6222d1fb19 100644 --- a/orttraining/orttraining/models/gpt2/main.cc +++ b/orttraining/orttraining/models/gpt2/main.cc @@ -26,11 +26,6 @@ namespace onnxruntime { std::unique_ptr CreateCUDAPinnedAllocator(const char* name); } // namespace onnxruntime #endif -#ifdef USE_ROCM -namespace onnxruntime { -std::unique_ptr CreateROCMPinnedAllocator(const char* name); -} // namespace onnxruntime -#endif using namespace onnxruntime; using namespace onnxruntime::common; @@ -368,16 +363,6 @@ void 
setup_training_params(GPT2Parameters& params) { } #endif -#ifdef USE_ROCM - { - OrtROCMProviderOptions info; - info.device_id = gsl::narrow(MPIContext::GetInstance().GetLocalRank()); - info.do_copy_in_default_stream = true; - params.providers.emplace(kRocmExecutionProvider, RocmProviderFactoryCreator::Create(&info)); - params.input_allocator = CreateROCMPinnedAllocator(HIP_PINNED); - } -#endif - params.use_nccl = true; params.error_function = [params](const std::vector& /*feed_names*/, diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index f870f445ead91..da2ebb67cb52e 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -138,8 +138,7 @@ class TrainingRunner { } bool UseCuda() const { - return providers.find(kCudaExecutionProvider) != providers.end() || - providers.find(kRocmExecutionProvider) != providers.end(); + return providers.find(kCudaExecutionProvider) != providers.end(); } AdasumReductionType GetAdasumReductionType() const { diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 1aa8d2090b511..3d611a0881fdf 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -33,11 +33,6 @@ const CUDAExecutionProviderInfo GetCudaExecutionProviderInfo(ProviderInfo_CUDA* const ProviderOptionsMap& provider_options_map); #endif -#ifdef USE_ROCM -const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* rocm_provider_info, - const ProviderOptionsMap& provider_options_map); -#endif - void addGlobalMethods(py::module& m); void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); void addObjectMethodsForTraining(py::module& m); @@ -77,7 +72,7 @@ bool GetDynamicExecutionProviderHash( bool GetProviderInstanceHash(const std::string& type, const ProviderOptionsMap& provider_options_map, size_t& hash) { - // for built-in execution provider, currently only cpu / cuda / rocm support hash. + // for built-in execution provider, currently only cpu / cuda support hash. if (type == kCpuExecutionProvider) { // for CPU, only 1 instance hash = 0; @@ -90,15 +85,6 @@ bool GetProviderInstanceHash(const std::string& type, hash = std::hash{}(info); return true; } -#endif - } else if (type == kRocmExecutionProvider) { -#ifdef USE_ROCM - if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) { - const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info, - provider_options_map); - hash = std::hash{}(info); - return true; - } #endif } else { const auto it = provider_options_map.find(type); diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 4bc470c633437..4c27fd923c96b 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -138,7 +138,6 @@ def _checkpoint( ORTMODULE_IS_DETERMINISTIC = torch.are_deterministic_algorithms_enabled() ONNXRUNTIME_CUDA_VERSION = ort_info.cuda_version if hasattr(ort_info, "cuda_version") else None -ONNXRUNTIME_ROCM_VERSION = ort_info.rocm_version if hasattr(ort_info, "rocm_version") else None # The first value indicates whether the code is in ONNX export context. 
# The export context here include the full export process, including prepare export input/output information, diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index d0447f1a96b17..0b7b338c27344 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -10,7 +10,6 @@ import onnx import torch -from torch.utils.cpp_extension import ROCM_HOME import onnxruntime from onnxruntime.capi import _pybind_state as C @@ -91,8 +90,6 @@ def __init__( # To be instantiated in the concrete implementation of GraphExecutionManager self._export_mode = export_mode - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) - # WIP feature to enable caching in Gradient accumulation scenario. self._gradient_accumulation_manager = GradientAccumulationManager() @@ -194,23 +191,22 @@ def _get_session_config(self): provider_options = None if self._device.type == "cuda": # Configure the InferenceSessions to use the specific GPU on which the model is placed. - providers = ["ROCMExecutionProvider"] if self.is_rocm_pytorch else ["CUDAExecutionProvider"] + providers = ["CUDAExecutionProvider"] providers.append("CPUExecutionProvider") provider_option_map = {"device_id": str(self._device.index)} - if not self.is_rocm_pytorch: - # Set Conv algo search mode to HEURISTIC by default, which is the same as PyTorch's default setting. - provider_option_map["cudnn_conv_algo_search"] = self._runtime_options.conv_algo_search - provider_option_map["cudnn_conv_use_max_workspace"] = "1" - provider_option_map["cudnn_conv1d_pad_to_nc1d"] = "1" - if self._runtime_options.enable_tuning: - provider_option_map["tunable_op_enable"] = "1" - provider_option_map["tunable_op_tuning_enable"] = "1" - if self._runtime_options.max_tuning_duration_ms: - provider_option_map["tunable_op_max_tuning_duration_ms"] = str( - self._runtime_options.max_tuning_duration_ms - ) - elif self._runtime_options.tuning_results_path: - provider_option_map["tunable_op_enable"] = "1" + # Set Conv algo search mode to HEURISTIC by default, which is the same as PyTorch's default setting. 
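Concretely, for a module placed on cuda:0 the code above now always builds a CUDA-plus-CPU provider list. A sketch of the resulting configuration as it would be handed to an InferenceSession; the option names come from the diff, while the device id and the session construction line are illustrative:

providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
provider_options = [
    {
        "device_id": "0",
        "cudnn_conv_algo_search": "HEURISTIC",   # PyTorch-like default, overridable via the runtime options
        "cudnn_conv_use_max_workspace": "1",
        "cudnn_conv1d_pad_to_nc1d": "1",
    },
    {},  # CPUExecutionProvider: no options
]
# e.g. onnxruntime.InferenceSession(exported_model_bytes, session_options,
#                                   providers=providers, provider_options=provider_options)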
+ provider_option_map["cudnn_conv_algo_search"] = self._runtime_options.conv_algo_search + provider_option_map["cudnn_conv_use_max_workspace"] = "1" + provider_option_map["cudnn_conv1d_pad_to_nc1d"] = "1" + if self._runtime_options.enable_tuning: + provider_option_map["tunable_op_enable"] = "1" + provider_option_map["tunable_op_tuning_enable"] = "1" + if self._runtime_options.max_tuning_duration_ms: + provider_option_map["tunable_op_max_tuning_duration_ms"] = str( + self._runtime_options.max_tuning_duration_ms + ) + elif self._runtime_options.tuning_results_path: + provider_option_map["tunable_op_enable"] = "1" if self._runtime_options.use_external_gpu_allocator: provider_option_map["gpu_external_alloc"] = str(self._torch_alloc) provider_option_map["gpu_external_free"] = str(self._torch_free) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index b4303587e69e6..3fa3e1cdaf461 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -342,8 +342,6 @@ def _build_graph(self, graph_transformer_config): # Apply registered graph transformers to the optimized model device_type = self._device.type - if device_type == "cuda" and self.is_rocm_pytorch: - device_type = "rocm" GraphOptimizerRegistry.optimize_all( type(self._flattened_module._original_module).__name__, device_type, self._onnx_models.optimized_model.graph ) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py index e6b1f0fb8b391..561ebec1e7cc4 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py @@ -19,7 +19,6 @@ The following environment variables are available for the extensions setup.py - ORTMODULE_TORCH_CPP_DIR: ORTModule's internal - - ONNXRUNTIME_ROCM_VERSION: ROCM version used to build ONNX Runtime package - ONNXRUNTIME_CUDA_VERSION: CUDA version used to build ONNX Runtime package - ONNXRUNTIME_FORCE_CUDA: Force CUDA extensions to be used when it is not available to build ONNX Runtime package diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py index 6b028d8f05e11..ad3c298187ca3 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py @@ -18,12 +18,10 @@ os.path.join(os.path.dirname(__file__), "multi_tensor_l2norm_kernel.cu"), ] -use_rocm = bool(os.environ["ONNXRUNTIME_ROCM_VERSION"]) extra_compile_args = {"cxx": ["-O3"]} -if not use_rocm: - nvcc_extra_args = os.environ.get("ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS", "") - if nvcc_extra_args: - extra_compile_args.update({"nvcc": nvcc_extra_args.split(",")}) +nvcc_extra_args = os.environ.get("ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS", "") +if nvcc_extra_args: + extra_compile_args.update({"nvcc": nvcc_extra_args.split(",")}) setup( name="fused_ops", diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py 
b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py index bdcb6daa233e6..68e15fb945361 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py @@ -11,9 +11,9 @@ from torch.utils import cpp_extension # TODO: Implement a cleaner way to auto-generate torch_gpu_allocator.cc -use_rocm = bool(os.environ["ONNXRUNTIME_ROCM_VERSION"]) -gpu_identifier = "hip" if use_rocm else "cuda" -gpu_allocator_header = "HIPCachingAllocator" if use_rocm else "CUDACachingAllocator" + +gpu_identifier = "cuda" +gpu_allocator_header = "CUDACachingAllocator" filename = os.path.join(os.path.dirname(__file__), "torch_gpu_allocator.cc") with fileinput.FileInput(filename, inplace=True) as file: for line in file: @@ -24,10 +24,9 @@ sys.stdout.write(line) extra_compile_args = {"cxx": ["-O3"]} -if not use_rocm: - nvcc_extra_args = os.environ.get("ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS", "") - if nvcc_extra_args: - extra_compile_args.update({"nvcc": nvcc_extra_args.split(",")}) +nvcc_extra_args = os.environ.get("ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS", "") +if nvcc_extra_args: + extra_compile_args.update({"nvcc": nvcc_extra_args.split(",")}) setup( name="torch_gpu_allocator", diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py index b26259b8abf94..d36f0d872f4df 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py @@ -53,7 +53,7 @@ def build_torch_cpp_extensions(): """Builds PyTorch CPP extensions and returns metadata.""" # Run this from within onnxruntime package folder is_gpu_available = (torch.version.cuda is not None or torch.version.hip is not None) and ( - ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None + ortmodule.ONNXRUNTIME_CUDA_VERSION is not None ) # Docker build don't have CUDA support, but Torch C++ extensions with CUDA may be forced @@ -61,26 +61,23 @@ def build_torch_cpp_extensions(): os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR) - # Extensions might leverage CUDA/ROCM versions internally + # Extensions might leverage CUDA versions internally os.environ["ONNXRUNTIME_CUDA_VERSION"] = ( ortmodule.ONNXRUNTIME_CUDA_VERSION if ortmodule.ONNXRUNTIME_CUDA_VERSION is not None else "" ) - os.environ["ONNXRUNTIME_ROCM_VERSION"] = ( - ortmodule.ONNXRUNTIME_ROCM_VERSION if ortmodule.ONNXRUNTIME_ROCM_VERSION is not None else "" - ) if torch.version.cuda is not None and ortmodule.ONNXRUNTIME_CUDA_VERSION is not None: _get_cuda_extra_build_params() ############################################################################ - # Pytorch CPP Extensions that DO require CUDA/ROCM + # Pytorch CPP Extensions that DO require CUDA ############################################################################ if is_gpu_available or force_cuda: for ext_setup in _list_cuda_extensions(): _install_extension(ext_setup.split(os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) ############################################################################ - # Pytorch CPP Extensions that DO NOT require CUDA/ROCM + # Pytorch CPP Extensions that DO NOT require CUDA 
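With the ROCm branch gone, the extra nvcc flags for these extension builds come only from ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS. An illustrative way to drive the setup scripts above; the flag values are just examples, and the variable is split on commas, one nvcc argument per entry:

import os

# Each comma-separated entry becomes one extra nvcc argument for fused_ops and torch_gpu_allocator.
os.environ["ONNXRUNTIME_CUDA_NVCC_EXTRA_ARGS"] = "-lineinfo,--use_fast_math"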
############################################################################ for ext_setup in _list_cpu_extensions(): _install_extension(ext_setup.split(os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) @@ -98,7 +95,6 @@ def build_torch_cpp_extensions(): # Tear down os.environ.pop("ONNXRUNTIME_CUDA_VERSION") - os.environ.pop("ONNXRUNTIME_ROCM_VERSION") if __name__ == "__main__": diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index b30540ec68317..0a837254fa619 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -42,8 +42,6 @@ std::vector> GetExecutionProviders( result.emplace_back(DefaultCpuExecutionProvider()); } else if (entry->Type() == onnxruntime::kCudaExecutionProvider) { result.emplace_back(DefaultCudaExecutionProvider()); - } else if (entry->Type() == onnxruntime::kRocmExecutionProvider) { - result.emplace_back(DefaultRocmExecutionProvider()); } else if (entry->Type() == onnxruntime::kDnnlExecutionProvider) { result.emplace_back(DefaultDnnlExecutionProvider()); } else if (entry->Type() == onnxruntime::kTensorrtExecutionProvider) { @@ -59,9 +57,6 @@ std::vector> GetExecutionProviders( } #ifdef USE_CUDA result.emplace_back(DefaultCudaExecutionProvider()); -#endif -#ifdef USE_ROCM - result.emplace_back(DefaultRocmExecutionProvider()); #endif result.emplace_back(DefaultCpuExecutionProvider()); return result; diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index 58c173ed90277..b362abbdde3d6 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -59,7 +59,6 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, static const std::string all_provider_types[] = { kCpuExecutionProvider, kCudaExecutionProvider, - kRocmExecutionProvider, kDnnlExecutionProvider, kTensorrtExecutionProvider, }; @@ -114,8 +113,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, execution_provider = DefaultDnnlExecutionProvider(); else if (provider_type == onnxruntime::kTensorrtExecutionProvider) execution_provider = DefaultTensorrtExecutionProvider(); - else if (provider_type == onnxruntime::kRocmExecutionProvider) - execution_provider = DefaultRocmExecutionProvider(); + // skip if execution provider is disabled if (execution_provider == nullptr) continue; diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 2d0181b69413c..ae2b144bfdb69 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -2231,7 +2231,6 @@ TEST(GradientUtilsTest, InPlaceAccumulatorV2Overwrite) { } #if defined(USE_CUDA) -// TODO: Add rocm kernel defs TEST(GradientUtilsTest, InPlaceAccumulatorV2_GPU) { std::vector> test_dims{ {768}, @@ -2276,7 +2275,7 @@ TEST(GradientUtilsTest, InPlaceAccumulatorV2_Float16) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) { OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain); @@ -2294,7 +2293,7 @@ TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) { // Didn't implement mixed precision InPlaceAccumulator in CPU test.Run(OpTester::ExpectResult::kExpectSuccess, "", 
{kCpuExecutionProvider});
}
-#endif // defined(USE_CUDA) || defined(USE_ROCM)
+#endif // defined(USE_CUDA)
TEST(GradientUtilsTest, ZeroGradientFloat32) {
OpTester test("ZeroGradient", 1, onnxruntime::kMSDomain);
@@ -2307,7 +2306,7 @@ TEST(GradientUtilsTest, ZeroGradientFloat32) {
test.Run();
}
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
TEST(GradientUtilsTest, ZeroGradientFloat16) {
OpTester test("ZeroGradient", 1, onnxruntime::kMSDomain);
@@ -2327,7 +2326,7 @@ TEST(GradientUtilsTest, ZeroGradientFloat16) {
test.Run();
}
-#endif // defined(USE_CUDA) || defined(USE_ROCM)
+#endif // defined(USE_CUDA)
TEST(GradientCheckerTest, WhereGrad) {
float max_error;
@@ -3019,7 +3018,6 @@ TEST(GradientCheckerTest, TriluGrad) {
}
}
-// TODO (enable once found why it fails on ROCM)
#if defined(USE_CUDA)
TEST(GradientCheckerTest, PadAndUnflattenGrad) {
float max_error;
@@ -3035,8 +3033,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
std::vector> execution_providers;
#ifdef USE_CUDA
execution_providers.emplace_back(DefaultCudaExecutionProvider());
-#elif USE_ROCM
- execution_providers.emplace_back(DefaultRocmExecutionProvider());
#endif
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
@@ -3065,8 +3061,6 @@ TEST(GradientCheckerTest, ScaledSumGrad) {
std::vector> execution_providers;
#ifdef USE_CUDA
execution_providers.emplace_back(DefaultCudaExecutionProvider());
-#elif USE_ROCM
- execution_providers.emplace_back(DefaultRocmExecutionProvider());
#endif
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, y_info},
@@ -3097,8 +3091,6 @@ TEST(GradientCheckerTest, ScaledSumGrad) {
std::vector> execution_providers;
#ifdef USE_CUDA
execution_providers.emplace_back(DefaultCudaExecutionProvider());
-#elif USE_ROCM
- execution_providers.emplace_back(DefaultRocmExecutionProvider());
#endif
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, y_info, z_info},
@@ -3318,7 +3310,6 @@ TEST(GradientCheckerTest, ConvTransposeGrad) {
ConvTransposeGradientCheckerTest(&execution_providers);
}
-// TODO: Enable test for ROCM
TEST(GradientCheckerTest, ResizeGrad) {
std::vector> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
diff --git a/orttraining/orttraining/test/gradient/optimizer_ops_test.cc b/orttraining/orttraining/test/gradient/optimizer_ops_test.cc
index 18c1364f5d1f6..96830510e8ebd 100644
--- a/orttraining/orttraining/test/gradient/optimizer_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/optimizer_ops_test.cc
@@ -254,7 +254,7 @@ TEST(OptimizerTest, AdamWeightDecayMode1WithBiasCorrection) {
test.Run();
}
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
float GetGradientL2Norm(const std::vector& gradient_vector) {
float gradient_norm = 0.0f;
diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc
index edabcb67aa586..1c2d71d4b4f90 100644
--- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc
+++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc
@@ -16,7 +16,7 @@
#include "orttraining/training_ops/cpu/controlflow/event_pool.h" // TODO: move with PipelineBatchPlanner
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
#include "bert_toy_fetches.h"
#endif
@@ -382,7 +382,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithProfiler) {
ASSERT_TRUE(count > 1);
}
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
static void RunBertTrainingWithChecks(
const SessionOptions& so,
const PathString& backprop_model_file) {
@@ -401,8 +401,6 @@ static void RunBertTrainingWithChecks(
#ifdef USE_CUDA
ASSERT_STATUS_OK(training_session->RegisterExecutionProvider(DefaultCudaExecutionProvider()));
-#elif USE_ROCM
- ASSERT_STATUS_OK(training_session->RegisterExecutionProvider(DefaultRocmExecutionProvider()));
#endif
ASSERT_STATUS_OK(training_session->Initialize());
@@ -579,7 +577,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_BertToy) {
PathString backprop_model_file;
ASSERT_STATUS_OK(BuildBackPropGraph(model_path, config, backprop_model_file));
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
SessionOptions so;
RunBertTrainingWithChecks(so, backprop_model_file);
#endif
diff --git a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc
index d7f0f2ce9b743..2f7d6532c9cb1 100644
--- a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc
+++ b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc
@@ -559,8 +559,6 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_MlmBertE2E) {
onnxruntime::kCpuExecutionProvider,
#ifdef USE_CUDA
onnxruntime::kCudaExecutionProvider,
-#elif USE_ROCM
- onnxruntime::kRocmExecutionProvider,
#endif
};
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 7d5feb9742366..7b89b241b884e 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -831,7 +831,6 @@ def run_step(model, x):
ort_model._is_training()
)._execution_agent._inference_session._provider_options
- # cudnn_conv_algo_search is for CUDA only, so setting the system env will not affect the compute on ROCm.
if "CUDAExecutionProvider" in provider_options:
expected_conv_algo_search = "HEURISTIC" if conv_algo_search is None else conv_algo_search
actual_conv_algo_search = provider_options["CUDAExecutionProvider"]["cudnn_conv_algo_search"]
@@ -6173,11 +6172,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters(), strict=False):
_test_helpers.assert_values_are_close(pt_param.grad, ort_param.grad, atol=1e-4, rtol=1e-5)
- if os.getenv("ORTMODULE_ROCM_TEST", "0") == "1":
- # For ROCm EP, the difference between ORT and PyTorch is larger than CUDA EP.
- _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=2e-3, rtol=2e-4)
- else:
- _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
+ _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
@@ -6332,9 +6327,6 @@ def run_step(model, x):
_test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad)
-@pytest.mark.skipif(
- os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm"
-)
@pytest.mark.parametrize("use_fp16", [False, True])
@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"])
def test_conv_transpose_gradient(use_fp16, conv_algo_search):
@@ -6404,9 +6396,6 @@ def run_step(model, x):
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
-@pytest.mark.skipif(
- os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm"
-)
@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"])
def test_conv_transpose_gradient_with_groups(conv_algo_search):
class TransposedConv3DWithGroups(nn.Module):
@@ -6450,9 +6439,6 @@ def run_step(model, x):
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
-@pytest.mark.skipif(
- os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm"
-)
@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"])
def test_conv_transpose_gradient_with_strides_padding_and_dilation(conv_algo_search):
class ConvTransposeComplexModel(nn.Module):
@@ -6644,8 +6630,6 @@ def run_step(model, attn_weight):
assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"
-# TODO: fix the issue in rocm training, then enable the test.
-@pytest.mark.skip(reason="This test is disabled due to its breaking rocm training cis.")
def test_aten_conv_bf16():
class NeuralNetConv(torch.nn.Module):
def __init__(self):
@@ -6920,11 +6904,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
for ort_param1, ort_param2 in zip(ort_model1.parameters(), ort_model2.parameters(), strict=False):
_test_helpers.assert_values_are_close(ort_param1.grad, ort_param2.grad, atol=1e-4, rtol=1e-5)
- if os.getenv("ORTMODULE_ROCM_TEST", "0") == "1":
- # For ROCm EP, the difference between ORT and PyTorch is larger than CUDA EP.
- _test_helpers.assert_values_are_close(ort_prediction1, ort_prediction2, atol=2e-3, rtol=2e-4)
- else:
- _test_helpers.assert_values_are_close(ort_prediction1, ort_prediction2, atol=1e-3, rtol=1e-4)
+ _test_helpers.assert_values_are_close(ort_prediction1, ort_prediction2, atol=1e-3, rtol=1e-4)
execution_mgr = ort_model2._torch_module._execution_manager._training_manager
from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name # noqa: PLC0415
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
index d977d96e82503..4ad615597b8b8 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
@@ -151,7 +151,7 @@ def test_softmax_bf16_large(self):
raise unittest.SkipTest("Temporarily disabled pending investigation")
if torch.version.cuda is None:
- # Only run this test when CUDA is available, as on ROCm BF16 is not supported by MIOpen.
+ # Only run this test when CUDA is available.
return
class Model(torch.nn.Module):
diff --git a/orttraining/orttraining/test/session/training_session_test.cc b/orttraining/orttraining/test/session/training_session_test.cc
index a3f6d917a76b6..e91c714165f0f 100644
--- a/orttraining/orttraining/test/session/training_session_test.cc
+++ b/orttraining/orttraining/test/session/training_session_test.cc
@@ -57,7 +57,7 @@ TEST(TrainingSessionTest, LoadOptimState_FullPrecision_FP32Moments_Adam) {
RunTrainingSessionLoadOptimTests(k_adam_optimizer_op_name, false, false);
}
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
TEST(TrainingSessionTest, LoadOptimState_MixedPrecision_FP32Moments_Adam) {
RunTrainingSessionLoadOptimTests(k_adam_optimizer_op_name, true, false);
}
diff --git a/orttraining/orttraining/test/session/training_session_test_utils.cc b/orttraining/orttraining/test/session/training_session_test_utils.cc
index 868388d4b9a93..07f375fa747e9 100644
--- a/orttraining/orttraining/test/session/training_session_test_utils.cc
+++ b/orttraining/orttraining/test/session/training_session_test_utils.cc
@@ -100,7 +100,7 @@ void VerifyState(const DataTransferManager& data_transfer_mgr, const NameMLValMa
const auto& e_state_it = expected_state.find(key);
ORT_ENFORCE(e_state_it != expected_state.end());
auto& expected_tensor = e_state_it->second.Get();
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
auto& actual_gpu_tensor = a_state_it.second.Get();
// Copying tensor to CPU when cuda is enabled.
@@ -181,8 +181,6 @@ std::unique_ptr BuildAndRunTrainingSessionWithChecks( #ifdef USE_CUDA ORT_THROW_IF_ERROR(training_session->RegisterExecutionProvider(DefaultCudaExecutionProvider())); -#elif USE_ROCM - ORT_THROW_IF_ERROR(training_session->RegisterExecutionProvider(DefaultRocmExecutionProvider())); #endif ORT_THROW_IF_ERROR(training_session->Initialize()); diff --git a/orttraining/orttraining/test/session/training_session_test_utils.h b/orttraining/orttraining/test/session/training_session_test_utils.h index 4ba092b951081..866855ef8d747 100644 --- a/orttraining/orttraining/test/session/training_session_test_utils.h +++ b/orttraining/orttraining/test/session/training_session_test_utils.h @@ -18,8 +18,6 @@ #ifdef USE_CUDA #include "core/providers/cuda/cuda_execution_provider_info.h" -#elif USE_ROCM -#include "core/providers/rocm/rocm_execution_provider_info.h" #endif namespace onnxruntime { diff --git a/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc b/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc index d8d2cb8e83550..8caaeeda76936 100644 --- a/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc @@ -10,7 +10,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(IsFiniteTest, Float) { OpTester test("IsFinite", 1, kMSDomain); @@ -256,4 +256,4 @@ TEST(IsAllFiniteTest, MoreFalseFloatTensorLargeFloat16) { #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc index 60c3ecbcce8ce..edbcb54fc5261 100644 --- a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) void test_all_1d_true(size_t size) { std::unique_ptr p_data(new bool[size]); @@ -100,7 +100,7 @@ TEST_P(ReductionOpTest, ReduceAllL2) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST_P(ReductionOpTest, ReduceAllL2HalfHalf) { OpTester test("ReduceAllL2", 1, onnxruntime::kMSDomain, true); test.SetDeterminism(GetParam()); @@ -164,7 +164,7 @@ TEST_P(ReductionOpTest, ReduceAllL2HalfFloat) { } #endif -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST_P(ReductionOpTest, ReduceAllL2_BFloat16_BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -193,8 +193,6 @@ TEST_P(ReductionOpTest, ReduceAllL2_BFloat16_BFloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -226,8 +224,6 @@ TEST_P(ReductionOpTest, ReduceAllL2_BFloat16_Float) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -257,8 +253,6 @@ TEST_P(ReductionOpTest, 
ReduceAllL2_Float_BFloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -441,7 +435,7 @@ TEST(ReductionOpTest, ReduceSumTraining_neg_axis) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(ReductionOpTest, ReduceSumTrainingHalfHalf) { OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); test.AddAttribute("keepdims", (int64_t)0); diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc index ced03d9df5c29..24c6d69efbcb4 100644 --- a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc @@ -96,7 +96,7 @@ void RunGatherGradTestWithRandomData( } } // namespace -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) // TODO: Currently this cannot pass CI, due to GPU architecture problem TEST(GatherOpTest, Gather_axis0_indices2d_half) { #ifdef USE_CUDA @@ -186,7 +186,7 @@ TEST(GatherGradOpTest, GatherFewDistinctIndices) { RunGatherGradTestWithRandomData(0, {2, 32}, {6, 128}, absolute_error); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) namespace { void RunGatherGradConsistentOutputTest( int64_t axis, diff --git a/orttraining/orttraining/test/training_ops/cuda/activations_test.cc b/orttraining/orttraining/test/training_ops/cuda/activations_test.cc index 3173610597f71..a974c2c8e2884 100644 --- a/orttraining/orttraining/test/training_ops/cuda/activations_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/activations_test.cc @@ -8,8 +8,6 @@ namespace test { #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_ROCM -constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; #endif static void TestActivations(const std::vector& tensor_dim, diff --git a/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc b/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc index d842d4f1ea736..e99b700beeba5 100644 --- a/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc @@ -14,7 +14,7 @@ namespace test { using namespace onnxruntime::test; -#if USE_CUDA || USE_ROCM +#if USE_CUDA static void TestBatchNormInternal(bool test_double = false, bool T_is_half = false, bool T1_is_half = false, bool T2_is_half = false, const std::vector& input_output_dims = {2, 2, 2, 2}) { @@ -137,11 +137,9 @@ TEST(CudaKernelTest, BNInternalBasic) { // float case TestBatchNormInternal(); } -#ifndef USE_ROCM // MIOpen does not support double type TEST(CudaKernelTest, BNInternalDouble) { // double case TestBatchNormInternal(true); } -#endif // ndef USE_ROCM TEST(CudaKernelTest, BNInternalHalf) { // half case TestBatchNormInternal(false, true, true, true); @@ -196,7 +194,7 @@ TEST(CudaKernelTest, BNInternal1DInput) { // float case, 1d input test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } -#endif // USE_CUDA || USE_ROCM +#endif // USE_CUDA } // namespace test } // namespace contrib diff --git 
a/orttraining/orttraining/test/training_ops/cuda/batch_scale_test.cc b/orttraining/orttraining/test/training_ops/cuda/batch_scale_test.cc index eb229b82caa55..4700589de8d57 100644 --- a/orttraining/orttraining/test/training_ops/cuda/batch_scale_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/batch_scale_test.cc @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" diff --git a/orttraining/orttraining/test/training_ops/cuda/bitmask_dropout_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/bitmask_dropout_grad_test.cc index 434d1804931b0..3ede24ac9b1fe 100644 --- a/orttraining/orttraining/test/training_ops/cuda/bitmask_dropout_grad_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/bitmask_dropout_grad_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) #include #include @@ -10,23 +10,14 @@ #include "test/providers/provider_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/util/include/default_providers.h" -#ifdef USE_ROCM -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#else #include "core/providers/cuda/shared_inc/cuda_utils.h" -#endif namespace onnxruntime { namespace contrib { namespace test { -#ifdef USE_ROCM -using onnxruntime::rocm::BitmaskElementType; -using onnxruntime::rocm::kNumBitsPerBitmaskElement; -#else using onnxruntime::cuda::BitmaskElementType; using onnxruntime::cuda::kNumBitsPerBitmaskElement; -#endif namespace { @@ -85,8 +76,6 @@ void RunTest(const std::vector& input_dims) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(onnxruntime::test::DefaultRocmExecutionProvider()); #endif test.Run(onnxruntime::test::OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc index 691856c688c9f..7638596bc3997 100644 --- a/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/conv_grad_test.cc @@ -11,7 +11,7 @@ namespace test { using namespace std; using namespace onnxruntime::test; -#if USE_CUDA || USE_ROCM +#if USE_CUDA namespace { struct ConvGradOpAttributes { @@ -315,7 +315,7 @@ TEST(ConvTest, Conv3D_Bias) { TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}); TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}, true); } -#endif // USE_CUDA || USE_ROCM +#endif // USE_CUDA } // namespace test } // namespace contrib diff --git a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc index 61bd9c19f3541..9dd0e59438be9 100644 --- a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc @@ -15,8 +15,6 @@ namespace test { #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_ROCM 
-constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; #endif static void TestSoftmaxCrossEntropy(const std::vector& X_dims, @@ -423,8 +421,6 @@ static void TestSCELoss(const char* op, int opset_version, []() -> std::unique_ptr { #ifdef USE_CUDA return DefaultCudaExecutionProvider(); -#elif USE_ROCM - return DefaultRocmExecutionProvider(); #endif }, reduction, ignore_index, @@ -934,8 +930,6 @@ static void TestSoftmaxCrossEntropyLossInternalGrad(const std::vector& []() -> std::unique_ptr { #ifdef USE_CUDA return DefaultCudaExecutionProvider(); -#elif USE_ROCM - return DefaultRocmExecutionProvider(); #endif }, reduction, ignore_index, error_tolerance, has_bias, diff --git a/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc index dd5fa18ab3edd..ab87420d83dfb 100644 --- a/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc @@ -7,7 +7,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(FlattenAndUnpadTest, Int32Type2D) { std::vector input = {1, 1, 3, 2, 0, 3, 0, 4, diff --git a/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc index 07d407da8e8d2..f1b545bda92ab 100644 --- a/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc @@ -8,7 +8,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) #include "test/providers/kernel_compute_test_utils.h" #endif @@ -142,7 +142,7 @@ void RunTestWrapper() { RunTest({2, 1, 1, 2, 3, 2, 3}, {2, 1, 1, 2, 3, 2, 2}, true, -5LL); } -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) template void RunKernelComputeTest(std::initializer_list input_dims, std::initializer_list indices_dims, std::initializer_list indices_strides = {}, bool has_axis = false, @@ -154,8 +154,6 @@ void RunKernelComputeTest(std::initializer_list input_dims, std::initia GetData(input_dims, indices_dims, indices_strides, new_axis, dY_data, indices_data, dX_data); #ifdef USE_CUDA const char* provider = kCudaExecutionProvider; -#else // USE_ROCM - const char* provider = kRocmExecutionProvider; #endif onnxruntime::test::KernelComputeTester test("GatherElementsGrad", provider, 1, kMSDomain); if (has_axis) test.AddAttribute("axis", axis); @@ -193,7 +191,7 @@ TEST(GatherElementsGrad, double) { RunTestWrapper(); } TEST(GatherElementsGrad, MLFloat16) { RunTestWrapper(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(GatherElementsGrad, BFloat16) { #ifdef USE_CUDA @@ -219,7 +217,7 @@ TEST(GatherElementsGrad, IndicesUpdatesDontMatch) { test.Run(onnxruntime::test::OpTester::ExpectResult::kExpectFailure, ""); } -#if defined(ENABLE_STRIDED_TENSORS) && (defined(USE_CUDA) || defined(USE_ROCM)) +#if defined(ENABLE_STRIDED_TENSORS) && defined(USE_CUDA) TEST(GatherElementsGrad, Strided_float) { RunKernelComputeTestWrapper(); } TEST(GatherElementsGrad, Strided_double) { RunKernelComputeTestWrapper(); } diff --git 
a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc index 13ad2f6150acf..13ec8cdbe343c 100644 --- a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc @@ -10,8 +10,6 @@ namespace test { #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_ROCM -constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; #endif constexpr auto k_epsilon_default = 1e-5f; diff --git a/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc b/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc index 35b7e8d91d164..03101b9549269 100644 --- a/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc @@ -147,7 +147,7 @@ TEST(CudaKernelTest, MixedPrecisionScaleH2H) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_bfloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -166,8 +166,6 @@ TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -190,8 +188,6 @@ TEST(CudaKernelTest, MixedPrecisionScale_float_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -214,8 +210,6 @@ TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_float) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -238,8 +232,6 @@ TEST(CudaKernelTest, MixedPrecisionScale_half_bfloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -262,12 +254,10 @@ TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_half) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_ops/cuda/negativeloglikelihood_test.cc b/orttraining/orttraining/test/training_ops/cuda/negativeloglikelihood_test.cc index c13ec612135a8..2d78575f6de62 100644 --- a/orttraining/orttraining/test/training_ops/cuda/negativeloglikelihood_test.cc +++ 
b/orttraining/orttraining/test/training_ops/cuda/negativeloglikelihood_test.cc @@ -10,8 +10,6 @@ namespace test { #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_ROCM -constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; #endif static void TestNegativeLogLikelihoodLoss(CompareOpTester& test, const std::vector* X_dims, diff --git a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc index 9a86955e09379..1b179c2eb35c6 100644 --- a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc @@ -7,7 +7,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(PadAndUnflattenTest, FloatType1D) { std::vector input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f}; diff --git a/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc b/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc index 335e6295fbd7b..23b92e13af19b 100644 --- a/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc @@ -8,8 +8,6 @@ namespace test { #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_ROCM -constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; #endif static void TestReduceSum(const std::vector& X_dims, diff --git a/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc index 8fc13af8816be..f28cc1fda4c47 100644 --- a/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc @@ -7,7 +7,7 @@ namespace onnxruntime::test { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) namespace { @@ -22,8 +22,6 @@ TEST(ResizeGradTest, ResizeGradWithSizes) { std::vector> providers; #ifdef USE_CUDA providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - providers.emplace_back(DefaultRocmExecutionProvider()); #endif OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); @@ -51,8 +49,6 @@ TEST(ResizeGradTest, ResizeGradWithSizesHalf) { std::vector> providers; #ifdef USE_CUDA providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - providers.emplace_back(DefaultRocmExecutionProvider()); #endif OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); @@ -86,8 +82,6 @@ TEST(ResizeGradTest, ResizeGradWithSizesAndAlignCorners) { std::vector> providers; #ifdef USE_CUDA providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - providers.emplace_back(DefaultRocmExecutionProvider()); #endif OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); @@ -118,8 +112,6 @@ TEST(ResizeGradTest, ResizeGradWithScales) { std::vector> providers; #ifdef USE_CUDA providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - providers.emplace_back(DefaultRocmExecutionProvider()); #endif OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); @@ -152,8 +144,6 @@ TEST(ResizeGradTest, ResizeGradWithScalesHalf) { std::vector> providers; #ifdef USE_CUDA providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - providers.emplace_back(DefaultRocmExecutionProvider()); #endif OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); @@ -192,8 +182,6 @@ 
TEST(ResizeGradTest, ResizeGradWithScalesAndAlignCorners) { std::vector> providers; #ifdef USE_CUDA providers.emplace_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - providers.emplace_back(DefaultRocmExecutionProvider()); #endif OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); @@ -222,6 +210,6 @@ TEST(ResizeGradTest, ResizeGradWithScalesAndAlignCorners) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); } -#endif // defined(USE_CUDA) || defined(USE_ROCM) +#endif // defined(USE_CUDA) } // namespace onnxruntime::test diff --git a/orttraining/orttraining/test/training_ops/cuda/scale_test.cc b/orttraining/orttraining/test/training_ops/cuda/scale_test.cc index ec48cccf927b6..64875ec18835c 100644 --- a/orttraining/orttraining/test/training_ops/cuda/scale_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/scale_test.cc @@ -134,7 +134,7 @@ TEST(CudaKernelTest, ScaleHalfInt64ScaleDown) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) TEST(CudaKernelTest, ScaleBFloat16BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -152,8 +152,6 @@ TEST(CudaKernelTest, ScaleBFloat16BFloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -175,8 +173,6 @@ TEST(CudaKernelTest, ScaleFloatBFloat16) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/orttraining/orttraining/test/training_ops/cuda/scaled_sum_test.cc b/orttraining/orttraining/test/training_ops/cuda/scaled_sum_test.cc index ae55aaa1afb6b..ef6d8b9f46e3a 100644 --- a/orttraining/orttraining/test/training_ops/cuda/scaled_sum_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/scaled_sum_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" diff --git a/orttraining/orttraining/test/training_ops/cuda/softmax_dropout_test.cc b/orttraining/orttraining/test/training_ops/cuda/softmax_dropout_test.cc index 8c9ff298cad9c..1f81c6bed233c 100644 --- a/orttraining/orttraining/test/training_ops/cuda/softmax_dropout_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/softmax_dropout_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA)
#include
#include
@@ -74,8 +74,6 @@ void LaunchBiasSoftmaxDropoutTester(const std::vector& input_dims, cons
std::vector> eps;
#ifdef USE_CUDA
eps.push_back(DefaultCudaExecutionProvider());
-#elif USE_ROCM
- eps.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &eps);
}
@@ -175,8 +173,6 @@ void LaunchSoftmaxDropoutGradTester(const std::vector& dims, const std:
std::vector> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
-#elif USE_ROCM
- execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
diff --git a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
index 9ced022aab850..9a9467e74b506 100644
--- a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
@@ -8,8 +8,6 @@ namespace test {
#if USE_CUDA
constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider;
-#elif USE_ROCM
-constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider;
#endif
template
@@ -215,22 +213,14 @@ TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_LastAxis_Float16) {
std::vector dY_dims{8, 16, 2048};
std::vector Y_dims{8, 16, 2048};
std::vector dX_dims{8, 16, 2048};
-#if USE_ROCM
- TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, false, 1.5e-2, 1.5e-2);
-#else
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, false, 1e-3, 1e-3);
-#endif
}
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_LastAxis_Float16_NoPowerOfTwo) {
std::vector dY_dims{8, 16, 1500};
std::vector Y_dims{8, 16, 1500};
std::vector dX_dims{8, 16, 1500};
-#if USE_ROCM
- TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, false, 1.7e-2, 1.7e-2);
-#else
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, false, 1e-3, 1e-3);
-#endif
}
// large tensor to check cuda DNN softmax backward
@@ -246,26 +236,16 @@ TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_AllAxis_Float16) {
std::vector dY_dims{8, 16, 512};
std::vector Y_dims{8, 16, 512};
std::vector dX_dims{8, 16, 512};
-#if USE_ROCM
- TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, false, 1.5e-2, 1.5e-2);
- TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, false, 1.5e-2, 1.5e-2);
-#else
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, false, 1e-3, 1e-3);
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, false, 1e-3, 1e-3);
-#endif
}
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor_AllAxis_Float16_NoPowerOfTwo) {
std::vector dY_dims{8, 16, 1500};
std::vector Y_dims{8, 16, 1500};
std::vector dX_dims{8, 16, 1500};
-#if USE_ROCM
- TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, false, 2.5e-2, 2.5e-2);
- TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, false, 2.5e-2, 2.5e-2);
-#else
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, false, 1e-3, 1e-3);
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, false, 1e-3, 1e-3);
-#endif
}
TEST(CudaKernelTest, LogSoftmaxGrad_SmallTensor_LastAxis) {
@@ -294,23 +274,14 @@ TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_LastAxis_Float16) {
std::vector dY_dims{8, 16, 2048};
std::vector Y_dims{8, 16, 2048};
std::vector dX_dims{8, 16, 2048};
-#if USE_ROCM
- TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true, 3.5e-2, 3.5e-2);
-#else
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true, 1e-3, 1e-3);
-#endif
}
TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_LastAxis_Float16_NoPowerOfTwo) {
std::vector dY_dims{8, 16, 1500}; std::vector Y_dims{8, 16, 1500}; std::vector dX_dims{8, 16, 1500}; -#if USE_ROCM - // FIXME: Excessive numerical errors - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true, 1.0, 5e-2); -#else TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 2, true, 1e-3, 1e-3); -#endif } TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_AllAxis) { @@ -325,26 +296,16 @@ TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_AllAxis_Float16) { std::vector dY_dims{8, 16, 512}; std::vector Y_dims{8, 16, 512}; std::vector dX_dims{8, 16, 512}; -#if USE_ROCM - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true, 1.5e-2, 1.5e-2); - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true, 1.5e-2, 1.5e-2); -#else TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true, 1e-3, 1e-3); TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true, 1e-3, 1e-3); -#endif } TEST(CudaKernelTest, LogSoftmaxGrad_LargeTensor_AllAxis_Float16_NoPowerOfTwo) { std::vector dY_dims{8, 16, 1500}; std::vector Y_dims{8, 16, 1500}; std::vector dX_dims{8, 16, 1500}; -#if USE_ROCM - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true, 4.5e-2, 4.5e-2); - TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true, 4.5e-2, 4.5e-2); -#else TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 0, true, 1e-3, 1e-3); TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, 1, true, 1e-3, 1e-3); -#endif } static void TestSoftmaxGrad_13(const std::vector& dY_dims, diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 9e12fdcd2bb53..13f713da65eda 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -266,7 +266,7 @@ Status Parameter::ResetGrad() { if (device.Type() == OrtDevice::CPU) { memset(p_tensor->MutableDataRaw(), 0, p_tensor->SizeInBytes()); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) else if (device.Type() == OrtDevice::GPU) { ORT_NOT_IMPLEMENTED("Not implemented."); } diff --git a/orttraining/orttraining/training_ops/cuda/activation/bias_gelu_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/bias_gelu_grad_impl.cu index 1963fe0185211..314442cca2e51 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/bias_gelu_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/activation/bias_gelu_grad_impl.cu @@ -97,12 +97,7 @@ void LaunchBiasGeluGradDxKernel( const int num_elements_per_thread = GridDim::maxElementsPerThread; -#ifdef USE_ROCM - // Optimization for ROCm MI100 - const int max_threads_per_block = 512; -#else const int max_threads_per_block = GridDim::maxThreadsPerBlock; -#endif int num_threads_per_block = std::min(static_cast(CeilDiv(bias_size, num_elements_per_thread)), max_threads_per_block); diff --git a/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout.cc b/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout.cc index 59d2081335d50..147373e77655e 100644 --- a/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout.cc +++ b/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout.cc @@ -39,11 +39,7 @@ struct DispatchBiasSoftmaxDropoutImpl { } // namespace -#ifdef USE_ROCM -#define BIAS_SOFTMAX_DROPOUT_TYPES float, MLFloat16 -#else #define BIAS_SOFTMAX_DROPOUT_TYPES float, MLFloat16, double -#endif ONNX_OPERATOR_KERNEL_EX(BiasSoftmaxDropout, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()) diff --git a/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu 
b/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu index 72fbbf53bfb21..0c144dab8ca20 100644 --- a/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu @@ -21,11 +21,7 @@ __global__ void BiasSoftmaxDropoutKernel(T* dropout_output_data, bool* mask_data constexpr int kNextPowOfTwo = 1 << Log2Elements; constexpr int kWarpSize = kNextPowOfTwo < GPU_WARP_SIZE ? kNextPowOfTwo : GPU_WARP_SIZE; constexpr int kWarpIterations = kNextPowOfTwo / kWarpSize; -#ifdef USE_ROCM - constexpr int kWarpBatch = 1; -#else constexpr int kWarpBatch = (kNextPowOfTwo <= 128) ? 2 : 1; -#endif int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kWarpBatch; // last warp may have fewer batches. @@ -201,13 +197,8 @@ Status BiasSoftmaxDropoutImpl(cudaStream_t stream, const cudaDeviceProp& prop, c int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE_HOST); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. -#ifdef USE_ROCM - int batches_per_warp = 1; - constexpr int threads_per_block = 256; -#else int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; constexpr int threads_per_block = 128; -#endif constexpr int t_vec4_alignment = std::alignment_of>::value; constexpr int mask_vec4_alignment = std::alignment_of>::value; diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad.cc b/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad.cc index 7399d2e1e933a..51960b3540067 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad.cc @@ -37,11 +37,7 @@ struct DispatchSoftmaxDropoutGradImpl { } // namespace -#ifdef USE_ROCM -#define SOFTMAX_DROPOUT_GRAD_TYPES float, MLFloat16 -#else #define SOFTMAX_DROPOUT_GRAD_TYPES float, MLFloat16, double -#endif ONNX_OPERATOR_KERNEL_EX(SoftmaxDropoutGrad, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()) diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad_impl.cu index b48ab1b718786..60a15d1386f4d 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_dropout_grad_impl.cu @@ -19,11 +19,7 @@ __global__ void SoftmaxDropoutGradKernel(T* input_grad_data, const T* output_gra constexpr int kNextPowOfTwo = 1 << Log2Elements; constexpr int kWarpSize = kNextPowOfTwo < GPU_WARP_SIZE ? kNextPowOfTwo : GPU_WARP_SIZE; constexpr int kWarpIterations = kNextPowOfTwo / kWarpSize; -#ifdef USE_ROCM - constexpr int kWarpBatch = 1; -#else constexpr int kWarpBatch = (kNextPowOfTwo <= 128) ? 2 : 1; -#endif int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kWarpBatch; // last warp may have fewer batches. @@ -146,13 +142,8 @@ Status SoftmaxDropoutGradImpl(cudaStream_t stream, cudnnHandle_t cudnn_handle, T int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE_HOST); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. -#ifdef USE_ROCM - int batches_per_warp = 1; - constexpr int threads_per_block = 256; -#else int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; constexpr int threads_per_block = 128; -#endif constexpr int t_vec4_alignment = std::alignment_of>::value; constexpr int mask_vec4_alignment = std::alignment_of>::value; diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc b/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc index 5c73a25fb4c9a..1be96da593cd5 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc @@ -30,12 +30,7 @@ struct DispatchSoftmaxGradImpl { } // namespace -// MIOpen doesn't support double so ROCm kernel doesn't have double support for now. -#ifdef USE_ROCM -#define SOFTMAX_GRAD_TYPES float, MLFloat16, BFloat16 -#else #define SOFTMAX_GRAD_TYPES float, double, MLFloat16, BFloat16 -#endif #define REGISTER_SOFTMAX_GRAD_KERNEL(name) \ ONNX_OPERATOR_KERNEL_EX( \ diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu index 3b5bd895c1f54..0764f60b2d1b5 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_grad_impl.cu @@ -36,11 +36,7 @@ __global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad, constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; -#ifdef USE_ROCM - constexpr int WARP_BATCH = 1; -#else constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; -#endif int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; @@ -227,11 +223,7 @@ template Status SoftmaxGradImpl(cudaStream_t stream, cudnnHandle_t cudnn_handle, T* input_grad, const T* output_grad, const T* softmax_output, int element_count, int batch_count, bool is_log_softmax) { if (element_count == 0) return Status::OK(); -#ifdef USE_ROCM - if (element_count <= 1024 && element_count * sizeof(T) <= 4096) { -#else if (element_count <= 2048 && element_count * sizeof(T) <= 4096) { -#endif typedef AccumulationType_t AccT; int log2_elements = log2_ceil(element_count); const int next_power_of_two = 1 << log2_elements; @@ -240,13 +232,8 @@ Status SoftmaxGradImpl(cudaStream_t stream, cudnnHandle_t cudnn_handle, T* input int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE_HOST); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. -#ifdef USE_ROCM - int batches_per_warp = 1; - constexpr int threads_per_block = 256; -#else int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; constexpr int threads_per_block = 128; -#endif int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc index d5f4303414deb..270fc03e7ff7a 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc @@ -80,12 +80,8 @@ Status LayerNormGrad::ComputeInternal(OpKernelContext* p_op bias_grad_data = reinterpret_cast(bias_grad->template MutableData()); } -#ifndef USE_ROCM constexpr int part_size = 16; -#else - // Optimization for ROCm MI100 - constexpr int part_size = 64; -#endif + auto part_grad_gamma = GetScratchBuffer(part_size * n2, p_op_kernel_context->GetComputeStream()); auto part_grad_beta = GetScratchBuffer(part_size * n2, p_op_kernel_context->GetComputeStream()); @@ -135,12 +131,8 @@ Status InvertibleLayerNormGrad::ComputeInternal(OpKernelContext* p_op_k auto scale_grad_data = reinterpret_cast(scale_grad->template MutableData()); auto bias_grad_data = reinterpret_cast(bias_grad->template MutableData()); -#ifndef USE_ROCM constexpr int part_size = 16; -#else - // Optimization for ROCm MI100 - constexpr int part_size = 64; -#endif + auto part_grad_gamma = GetScratchBuffer(part_size * n2, p_op_kernel_context->GetComputeStream()); auto part_grad_beta = GetScratchBuffer(part_size * n2, p_op_kernel_context->GetComputeStream()); diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h index 75f8c243d3425..bb49ec2743123 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h @@ -3,11 +3,7 @@ #pragma once -#ifdef USE_ROCM -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#else #include "core/providers/cuda/shared_inc/cuda_utils.h" -#endif namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h index 8b015179cebd0..89d145294dfc4 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h @@ -3,11 +3,7 @@ #pragma once -#ifdef USE_ROCM -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#else #include "core/providers/cuda/shared_inc/cuda_utils.h" -#endif namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/rocm/activation/gelu_grad_impl_common.cuh b/orttraining/orttraining/training_ops/rocm/activation/gelu_grad_impl_common.cuh deleted file mode 100644 index 2377aae9abb54..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/activation/gelu_grad_impl_common.cuh +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
-
-#pragma once
-
-#include "core/providers/rocm/cu_inc/common.cuh"
-#include "orttraining/training_ops/cpu/activation/gelu_computation_mode.h"
-
-namespace onnxruntime {
-namespace rocm {
-
-template
-__device__ __inline__ T ComputeGeluGradScalar(T dY, T X, gelu_computation_mode::Default) {
- const T kAlpha = T(M_2_SQRTPI) * T(M_SQRT1_2) * T(0.5);
- return dY * (_Normcdf(X) + X * kAlpha * _Exp(-T(0.5) * X * X));
-}
-
-template
-__device__ __inline__ T ComputeGeluGradScalar(T dY, T X, gelu_computation_mode::Approximation) {
- // copied and adapted from DeepSpeed:
- // https://github.com/microsoft/DeepSpeed/blob/f5025506de37f617a93eabc2aed7cc4f4bfd7d80/csrc/transformer/gelu_kernels.cu#L10
-
- const float X_float = static_cast(X);
-
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715f;
-
- constexpr float one = 1.0;
- constexpr float two = 2.0;
-
- float x2mul = X_float * X_float * mul_param;
-
- // float tan_h = tanhf(sqrt_param * (X_float + X_float * x2mul));
- float u = two * sqrt_param * (X_float + X_float * x2mul);
- float emu = __expf(-u);
- float tan_h = two / (one + emu) - one;
-
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = X_float * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
-
- return dY * static_cast(dg1 + dg2 + dg3);
-}
-
-} // namespace rocm
-} // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/rocm/math/div_grad.cc b/orttraining/orttraining/training_ops/rocm/math/div_grad.cc
deleted file mode 100644
index 03669e33e7d2e..0000000000000
--- a/orttraining/orttraining/training_ops/rocm/math/div_grad.cc
+++ /dev/null
@@ -1,255 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "orttraining/training_ops/rocm/math/div_grad.h"
-#include "orttraining/training_ops/rocm/math/div_grad_impl.h"
-#include "core/providers/rocm/math/binary_elementwise_ops.h"
-
-using namespace onnxruntime::common;
-namespace onnxruntime {
-namespace rocm {
-
-#define DIVGRAD_REGISTER_KERNEL_TYPED(T) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
- DivGrad, \
- kMSDomain, \
- 1, \
- T, \
- kRocmExecutionProvider, \
- (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \
- DivGrad);
-
-DIVGRAD_REGISTER_KERNEL_TYPED(MLFloat16)
-DIVGRAD_REGISTER_KERNEL_TYPED(float)
-// DIVGRAD_REGISTER_KERNEL_TYPED(double)
-
-TensorShapeVector prepended_dimension_1(const TensorShape& shape, size_t total_rank) {
- size_t input_rank = shape.NumDimensions();
- if (input_rank == total_rank)
- return shape.AsShapeVector();
-
- TensorShapeVector dims(total_rank, 1);
-
- // https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
- // for property 3 of Multidirectional Broadcasting, we need to prepended with a dimension of length 1.
- if (input_rank > 0) - std::copy(shape.GetDims().begin(), shape.GetDims().end(), &dims[total_rank - input_rank]); - return dims; -} - -template -Status DivGrad::ComputeInternal(OpKernelContext* context) const { - typedef typename ToHipType::MappedType HipT; - - const Tensor* dy_tensor = context->Input(0); - const Tensor* a_tensor = context->Input(1); - const Tensor* b_tensor = context->Input(2); - const TensorShape& a_shape = a_tensor->Shape(); - const TensorShape& b_shape = b_tensor->Shape(); - const TensorShape& dy_shape = dy_tensor->Shape(); - - // output shapes shall match its corresponding inputs - Tensor* da_output_tensor = context->Output(0, a_shape); - Tensor* db_output_tensor = context->Output(1, b_shape); - if (!da_output_tensor && !db_output_tensor) - return Status::OK(); - - BinaryElementwisePreparation prepare; - ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(a_tensor, b_tensor, - // TODO: BinaryElementwiseBroadcastPrepare shall take dy_tensor as const Tensor*. - const_cast(dy_tensor), &prepare)); - const HipT* prepare_a_data = reinterpret_cast(prepare.lhs_tensor->template Data()); - const HipT* prepare_b_data = reinterpret_cast(prepare.rhs_tensor->template Data()); - const HipT* prepare_dy_data = reinterpret_cast(prepare.output_tensor->template Data()); - T* da_data = da_output_tensor ? da_output_tensor->template MutableData() : nullptr; - T* db_data = db_output_tensor ? db_output_tensor->template MutableData() : nullptr; - - switch (prepare.output_rank_or_simple_broadcast) { - case static_cast(SimpleBroadcast::NoBroadcast): - ImplDivGradSimple( - Stream(context), - SimpleBroadcast::NoBroadcast, - prepare_a_data, - prepare_b_data, - prepare_dy_data, - dy_shape.Size(), - reinterpret_cast(da_data), - reinterpret_cast(db_data)); - break; - case static_cast(SimpleBroadcast::LeftScalar): { - T* temp_da_data = nullptr; - IAllocatorUniquePtr temp_da_allocator; - if (da_output_tensor) { - temp_da_allocator = GetScratchBuffer(dy_shape.Size(), context->GetComputeStream()); - temp_da_data = temp_da_allocator.get(); - } - - ImplDivGradSimple( - Stream(context), - SimpleBroadcast::LeftScalar, - prepare_a_data, - prepare_b_data, - prepare_dy_data, - dy_shape.Size(), - reinterpret_cast(temp_da_data), - reinterpret_cast(db_data)); - - if (da_output_tensor) { - auto a_output_dims = prepended_dimension_1(a_shape, dy_shape.NumDimensions()); - ORT_RETURN_IF_ERROR((ReduceKernelShared( - temp_da_data, - dy_shape, - da_data, - TensorShape({}), - MIOPEN_REDUCE_TENSOR_ADD, - GetMiopenHandle(context), - context->GetComputeStream(), - a_output_dims))); - } - break; - } - case static_cast(SimpleBroadcast::RightScalar): { - T* temp_db_data = nullptr; - IAllocatorUniquePtr temp_db_allocator; - if (db_output_tensor) { - temp_db_allocator = GetScratchBuffer(dy_shape.Size(), context->GetComputeStream()); - temp_db_data = temp_db_allocator.get(); - } - ImplDivGradSimple( - Stream(context), - SimpleBroadcast::RightScalar, - prepare_a_data, - prepare_b_data, - prepare_dy_data, - dy_shape.Size(), - reinterpret_cast(da_data), - reinterpret_cast(temp_db_data)); - - if (db_output_tensor) { - auto b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions()); - ORT_RETURN_IF_ERROR((ReduceKernelShared( - temp_db_data, - dy_shape, - db_data, - TensorShape({}), - MIOPEN_REDUCE_TENSOR_ADD, - GetMiopenHandle(context), - context->GetComputeStream(), - b_output_dims))); - } - break; - } - case static_cast(SimpleBroadcast::RightPerChannelBatch1): - case static_cast(SimpleBroadcast::RightPerChannelBatchN): 
{ - T* temp_db_data = nullptr; - IAllocatorUniquePtr temp_db_allocator; - if (db_output_tensor) { - temp_db_allocator = GetScratchBuffer(dy_shape.Size(), context->GetComputeStream()); - temp_db_data = temp_db_allocator.get(); - } - if (prepare.output_rank_or_simple_broadcast == static_cast(SimpleBroadcast::RightPerChannelBatch1)) { - // lhs(1,C,H) and rhs (C,1) - ImplDivGradRhsPerChannelBatch1( - Stream(context), - prepare_a_data, - prepare_b_data, - prepare_dy_data, - dy_shape.Size(), - prepare.fdm_H, - reinterpret_cast(da_data), - reinterpret_cast(temp_db_data)); - } else { - // lhs(N,C,H) and rhs (C,1) - ImplDivGradRhsPerChannelBatchN( - Stream(context), - prepare_a_data, - prepare_b_data, - prepare_dy_data, - dy_shape.Size(), - prepare.fdm_H, - prepare.fdm_C, - reinterpret_cast(da_data), - reinterpret_cast(temp_db_data)); - } - - if (db_output_tensor) { - auto b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions()); - ORT_RETURN_IF_ERROR((ReduceKernelShared( - temp_db_data, - dy_shape, - db_data, - b_shape, - MIOPEN_REDUCE_TENSOR_ADD, - GetMiopenHandle(context), - context->GetComputeStream(), - b_output_dims))); - } - break; - } - default: { - bool need_reduce_da = da_output_tensor && a_shape.Size() != dy_shape.Size(); - bool need_reduce_db = db_output_tensor && b_shape.Size() != dy_shape.Size(); - IAllocatorUniquePtr temp_da_allocator, temp_db_allocator; - T* da_data_ref = nullptr; - if (da_output_tensor) { - if (need_reduce_da) { - temp_da_allocator = GetScratchBuffer(dy_shape.Size(), context->GetComputeStream()); - da_data_ref = temp_da_allocator.get(); - } else { - da_data_ref = da_data; - } - } - T* db_data_ref = nullptr; - if (db_output_tensor) { - if (need_reduce_db) { - temp_db_allocator = GetScratchBuffer(dy_shape.Size(), context->GetComputeStream()); - db_data_ref = temp_db_allocator.get(); - } else { - db_data_ref = db_data; - } - } - ImplDivGrad( - Stream(context), - prepare.output_rank_or_simple_broadcast, - prepare.lhs_padded_strides, - prepare_a_data, - prepare.rhs_padded_strides, - prepare_b_data, - prepare_dy_data, - dy_shape.Size(), - prepare.fdm_output_strides, - reinterpret_cast(da_data_ref), - reinterpret_cast(db_data_ref)); - - if (need_reduce_da) { - auto a_output_dims = prepended_dimension_1(a_shape, dy_shape.NumDimensions()); - ORT_RETURN_IF_ERROR((ReduceKernelShared( - da_data_ref, - dy_shape, - da_data, - a_shape, - MIOPEN_REDUCE_TENSOR_ADD, - GetMiopenHandle(context), - context->GetComputeStream(), - a_output_dims))); - } - - if (need_reduce_db) { - auto b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions()); - ORT_RETURN_IF_ERROR((ReduceKernelShared( - db_data_ref, - dy_shape, - db_data, - b_shape, - MIOPEN_REDUCE_TENSOR_ADD, - GetMiopenHandle(context), - context->GetComputeStream(), - b_output_dims))); - } - } - } - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc deleted file mode 100644 index b1072796bb4fa..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "orttraining/training_ops/rocm/nn/batch_norm_grad.h" -#include "core/providers/common.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/cpu/nn/batch_norm_helper.h" -#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" - -using namespace std; -namespace onnxruntime { -namespace rocm { - -#define REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - BatchNormalizationGrad, \ - kMSDomain, \ - 1, \ - T##_##T1##_##T2, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).TypeConstraint("T1", DataTypeImpl::GetTensorType()).TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - BatchNormalizationGrad); - -template -Status BatchNormalizationGrad::ComputeInternal(OpKernelContext* ctx) const { - typedef typename ToHipType::MappedType HipT; - typedef typename ToHipType::MappedType HipT1; - typedef typename ToHipType::MappedType HipT2; - - const Tensor* dY = ctx->Input(0); - const Tensor* X = ctx->Input(1); - const Tensor* Scale = ctx->Input(2); - const Tensor* saved_mean = ctx->Input(3); - // miopenBatchNormalizationBackward() claims to use `savedInvVariance`, but the value - // is actually equal to the batch inv_std, so we use name `saved_inv_std` here. - const Tensor* saved_inv_std = ctx->Input(4); - const TensorShape input_shape = X->Shape(); - const TensorShape channel_shape = saved_mean->Shape(); - - // no B here, but B has same size as Scale, so can validate inputs for gradient with this substitute - ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, Scale, Scale, saved_mean, saved_inv_std)); - - auto dY_data = reinterpret_cast(dY->template Data()); - auto X_data = reinterpret_cast(X->template Data()); - auto Scale_data = reinterpret_cast(Scale->template Data()); - auto saved_mean_data = reinterpret_cast(saved_mean->template Data()); - auto saved_inv_std_data = reinterpret_cast(saved_inv_std->template Data()); - - auto dX_data = reinterpret_cast(ctx->Output(0, input_shape)->template MutableData()); - auto dScale_data = reinterpret_cast(ctx->Output(1, channel_shape)->template MutableData()); - auto dBias_data = reinterpret_cast(ctx->Output(2, channel_shape)->template MutableData()); - - const auto alpha = Consts::One; - const auto beta = Consts::Zero; - - MiopenTensor input_tensor, scale_bias_tensor; - vector new_dims; - BatchNormHelper::NormalizeDims(input_shape, new_dims); - ORT_RETURN_IF_ERROR(input_tensor.Set(new_dims, MiopenTensor::GetDataType())); - // for fp16 input, `scale_bias_tensor` will have a float type; otherwise it will be the same as input type. 
- ORT_RETURN_IF_ERROR(scale_bias_tensor.Set(input_tensor, miopen_batch_norm_mode_)); - - const int64_t C = new_dims[1]; - auto p_scale = reinterpret_cast(Scale_data); - auto p_saved_mean = reinterpret_cast(saved_mean_data); - auto p_saved_inv_std = reinterpret_cast(saved_inv_std_data); - auto p_dScale = reinterpret_cast(dScale_data); - auto p_dBias = reinterpret_cast(dBias_data); - - IAllocatorUniquePtr p_f_scale, p_f_dScale, p_f_dBias, p_f_saved_mean, p_f_saved_inv_std; - - if (std::is_same::value) { - p_f_scale = GetScratchBuffer(C, ctx->GetComputeStream()); - p_f_dScale = GetScratchBuffer(C, ctx->GetComputeStream()); - p_f_dBias = GetScratchBuffer(C, ctx->GetComputeStream()); - - Impl_Cast(Stream(ctx), Scale_data, p_f_scale.get(), C); - - p_scale = p_f_scale.get(); - p_dScale = p_f_dScale.get(); - p_dBias = p_f_dBias.get(); - } - - if (std::is_same::value) { - p_f_saved_mean = GetScratchBuffer(C, ctx->GetComputeStream()); - p_f_saved_inv_std = GetScratchBuffer(C, ctx->GetComputeStream()); - - Impl_Cast(Stream(ctx), saved_mean_data, p_f_saved_mean.get(), C); - Impl_Cast(Stream(ctx), saved_inv_std_data, p_f_saved_inv_std.get(), C); - - p_saved_mean = p_f_saved_mean.get(); - p_saved_inv_std = p_f_saved_inv_std.get(); - } - - MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationBackward( - GetMiopenHandle(ctx), - miopen_batch_norm_mode_, - &alpha, - &beta, - &alpha, - &beta, - input_tensor, - X_data, - input_tensor, - dY_data, - input_tensor, - dX_data, - scale_bias_tensor, - p_scale, - p_dScale, - p_dBias, - epsilon_, - p_saved_mean, - p_saved_inv_std)); - - if (std::is_same::value) { - Impl_Cast(Stream(ctx), reinterpret_cast(p_dScale), dScale_data, C); - Impl_Cast(Stream(ctx), reinterpret_cast(p_dBias), dBias_data, C); - } - - return Status::OK(); -} - -#define SPECIALIZED_GRADIENT(T, T1, T2) \ - REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \ - template Status BatchNormalizationGrad::ComputeInternal(OpKernelContext* ctx) const; - -SPECIALIZED_GRADIENT(float, float, float) -// MIOpen kernel does not support double, disable for now. -// SPECIALIZED_GRADIENT(double, double, double) -SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, MLFloat16) -SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, float) -SPECIALIZED_GRADIENT(MLFloat16, float, float) - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h deleted file mode 100644 index 63d2370076bab..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/miopen_common.h" - -namespace onnxruntime { -namespace rocm { - -template -class BatchNormalizationGrad final : public RocmKernel { - public: - BatchNormalizationGrad(const OpKernelInfo& info) - : RocmKernel{info}, - miopen_batch_norm_mode_(miopenBNSpatial) { - float tmp_epsilon; - ORT_ENFORCE(info.GetAttr("epsilon", &tmp_epsilon).IsOK()); - epsilon_ = ClampMiopenBatchNormEpsilon(static_cast(tmp_epsilon)); - - // spatial or not - int64_t tmp_spatial; - if (info.GetAttr("spatial", &tmp_spatial).IsOK()) { - spatial_ = tmp_spatial; - } - - if (spatial_ == 0) { - miopen_batch_norm_mode_ = miopenBNPerActivation; - } - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - double epsilon_; - int64_t spatial_ = 1; // default as per spec - miopenBatchNormMode_t miopen_batch_norm_mode_; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.cc b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.cc deleted file mode 100644 index dbd1f95ddee95..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.cc +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "orttraining/training_ops/rocm/nn/batch_norm_internal.h" -#include "core/providers/common.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/cpu/nn/batch_norm_helper.h" -#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" - -using namespace std; -namespace onnxruntime { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T, T1, T2) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - BatchNormInternal, \ - kMSDomain, \ - 1, \ - T##_##T1##_##T2, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .Alias(3, 1) \ - .Alias(4, 2) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - BatchNormInternal); - -template -Status BatchNormInternal::ComputeInternal(OpKernelContext* p_op_kernel_context) const { - typedef typename ToHipType::MappedType HipT; - typedef typename ToHipType::MappedType HipT1; - typedef typename ToHipType::MappedType HipT2; - - const Tensor* X = p_op_kernel_context->Input(0); - const Tensor* scale = p_op_kernel_context->Input(1); - const Tensor* B = p_op_kernel_context->Input(2); - const Tensor* mean = p_op_kernel_context->Input(3); - const Tensor* var = p_op_kernel_context->Input(4); - - ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1)); - - const TensorShape& x_shape = X->Shape(); - const TensorShape& channel_shape = mean->Shape(); - - Tensor* Y = p_op_kernel_context->Output(0, x_shape); - Tensor* running_mean = p_op_kernel_context->Output(1, channel_shape); - Tensor* running_var = p_op_kernel_context->Output(2, channel_shape); - Tensor* saved_mean = p_op_kernel_context->Output(3, channel_shape); - // miopenBatchNormalizationForwardTraining() claims to output `resultSaveInvVariance`, but the value - // is actually equal to the batch inv_std, so we use name `saved_inv_std` here. 
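// Note: "inv_std" denotes the reciprocal of the batch standard deviation, i.e. 1 / sqrt(batch_var + epsilon),
// not the reciprocal of the variance, which is why the output is named saved_inv_std; the
// BatchNormalizationGrad kernel above consumes the tensor under the same convention.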
- Tensor* saved_inv_std = p_op_kernel_context->Output(4, channel_shape); - - auto x_data = reinterpret_cast(X->template Data()); - auto scale_data = reinterpret_cast(scale->template Data()); - auto b_data = reinterpret_cast(B->template Data()); - auto mean_data = reinterpret_cast(mean->template Data()); - auto var_data = reinterpret_cast(var->template Data()); - - auto y_data = reinterpret_cast(Y->template MutableData()); - - // In MIOpenBatchNormForward, alpha and beta are not const. - float alpha = 1.0; - float beta = 0.0; - - MiopenTensor data_desc, bn_tensor_desc; - vector new_dims; - BatchNormHelper::NormalizeDims(x_shape, new_dims); - ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, MiopenTensor::GetDataType())); - // for fp16 input, `bn_tensor_desc` will have a float type; otherwise it will be the same as input type. - ORT_RETURN_IF_ERROR(bn_tensor_desc.Set(data_desc, miopen_batch_norm_mode_)); - - auto running_mean_data = reinterpret_cast(running_mean->template MutableData()); - auto running_var_data = reinterpret_cast(running_var->template MutableData()); - auto saved_mean_data = reinterpret_cast(saved_mean->template MutableData()); - auto saved_inv_std_data = reinterpret_cast(saved_inv_std->template MutableData()); - - auto p_scale = reinterpret_cast(scale_data); - auto p_B = reinterpret_cast(b_data); - auto p_running_mean = reinterpret_cast(running_mean_data); - auto p_running_var = reinterpret_cast(running_var_data); - auto p_saved_mean = reinterpret_cast(saved_mean_data); - auto p_saved_inv_std = reinterpret_cast(saved_inv_std_data); - - const int64_t C = new_dims[1]; - IAllocatorUniquePtr p_f_scale, p_f_B, p_f_running_mean, p_f_running_var, p_f_saved_mean, p_f_saved_inv_std; - - if (std::is_same::value) { - // Convert scale/B to float - p_f_scale = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - p_f_B = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - - Impl_Cast(Stream(p_op_kernel_context), scale_data, p_f_scale.get(), C); - Impl_Cast(Stream(p_op_kernel_context), b_data, p_f_B.get(), C); - - p_scale = p_f_scale.get(); - p_B = p_f_B.get(); - } - - if (std::is_same::value) { - // Convert mean/var to float - p_f_running_mean = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - p_f_running_var = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - p_f_saved_mean = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - p_f_saved_inv_std = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - - Impl_Cast(Stream(p_op_kernel_context), mean_data, p_f_running_mean.get(), C); - Impl_Cast(Stream(p_op_kernel_context), var_data, p_f_running_var.get(), C); - - p_running_mean = p_f_running_mean.get(); - p_running_var = p_f_running_var.get(); - p_saved_mean = p_f_saved_mean.get(); - p_saved_inv_std = p_f_saved_inv_std.get(); - } else if (mean_data != running_mean_data) { - HIP_RETURN_IF_ERROR( - hipMemcpyAsync(running_mean_data, mean_data, C * sizeof(T2), hipMemcpyDeviceToDevice, Stream(p_op_kernel_context))); - HIP_RETURN_IF_ERROR( - hipMemcpyAsync(running_var_data, var_data, C * sizeof(T2), hipMemcpyDeviceToDevice, Stream(p_op_kernel_context))); - } - - // NOTE: in miopenBatchNorm, biased std/var is used when calculating `save_inv_std` and `y`, while - // `running_var` is updated using unbiased `batch_var`: - // running_var = (1 - momentum_) * unbiased_batch_var + momentum_ * running_var - // This is inconsistent with BatchNormalization Onnx spec, which uses population variance (biased). 
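// Concretely, with the ONNX momentum attribute value momentum_, the call below passes
// expAvgFactor = 1.0 - momentum_, so the running statistics are updated as
//   running_mean = (1 - momentum_) * batch_mean         + momentum_ * running_mean
//   running_var  = (1 - momentum_) * unbiased_batch_var + momentum_ * running_var
// whereas the ONNX BatchNormalization spec defines the same update with the biased
// (population) batch variance.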
- MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationForwardTraining( - GetMiopenHandle(p_op_kernel_context), - miopen_batch_norm_mode_, - &alpha, - &beta, - data_desc, - x_data, - data_desc, - y_data, - bn_tensor_desc, - const_cast(p_scale), - const_cast(p_B), - 1.0 - momentum_, - p_running_mean, - p_running_var, - epsilon_, - p_saved_mean, - p_saved_inv_std)); - - if (std::is_same::value) { - Impl_Cast(Stream(p_op_kernel_context), reinterpret_cast(p_running_mean), running_mean_data, C); - Impl_Cast(Stream(p_op_kernel_context), reinterpret_cast(p_running_var), running_var_data, C); - Impl_Cast(Stream(p_op_kernel_context), reinterpret_cast(p_saved_mean), saved_mean_data, C); - Impl_Cast(Stream(p_op_kernel_context), reinterpret_cast(p_saved_inv_std), saved_inv_std_data, C); - } - - return Status::OK(); -} - -#define SPECIALIZED_COMPUTE(T, T1, T2) \ - REGISTER_KERNEL_TYPED(T, T1, T2) \ - template Status BatchNormInternal::ComputeInternal(OpKernelContext* ctx) const; - -SPECIALIZED_COMPUTE(float, float, float) -// MIOpen kernel does not support double, disable for now. -// SPECIALIZED_COMPUTE(double, double, double) -SPECIALIZED_COMPUTE(MLFloat16, MLFloat16, MLFloat16) -SPECIALIZED_COMPUTE(MLFloat16, MLFloat16, float) -SPECIALIZED_COMPUTE(MLFloat16, float, float) - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.h b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.h deleted file mode 100644 index d65b66120a78c..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/miopen_common.h" - -namespace onnxruntime { -namespace rocm { - -template -class BatchNormInternal final : public RocmKernel { - public: - BatchNormInternal(const OpKernelInfo& op_kernel_info) - : RocmKernel{op_kernel_info}, - miopen_batch_norm_mode_(miopenBNSpatial), - momentum_(0.9) { - float tmp_epsilon; - ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &tmp_epsilon).IsOK()); - epsilon_ = ClampMiopenBatchNormEpsilon(static_cast(tmp_epsilon)); - - // spatial or not - int64_t tmp_spatial; - if (op_kernel_info.GetAttr("spatial", &tmp_spatial).IsOK()) { - spatial_ = tmp_spatial; - } - - if (spatial_ == 0) { - miopen_batch_norm_mode_ = miopenBNPerActivation; - } - - float tmp_momentum; - if (op_kernel_info.GetAttr("momentum", &tmp_momentum).IsOK()) { - momentum_ = static_cast(tmp_momentum); - } - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - double epsilon_; - int64_t spatial_ = 1; // default as per spec - miopenBatchNormMode_t miopen_batch_norm_mode_; - double momentum_; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc deleted file mode 100644 index 3b1ed29cb0240..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc +++ /dev/null @@ -1,385 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// TODO Add exhaustive and default cases for algo. 
- -#include "orttraining/training_ops/rocm/nn/conv_grad.h" - -#include "core/providers/common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include - -namespace onnxruntime { -namespace rocm { - -#define REGISTER_GRADIENT_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX(ConvGrad, kMSDomain, 1, T, kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvGrad); - -REGISTER_GRADIENT_KERNEL_TYPED(float) -// MIOpen double support not currently implemented. -// REGISTER_GRADIENT_KERNEL_TYPED(double) -REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16) - -using T_BwdDataPerf = miopenConvAlgoPerf_t; -using T_BwdDataAlgo = miopenConvBwdDataAlgorithm_t; -using T_BwdFilterPerf = miopenConvAlgoPerf_t; -using T_BwdFilterAlgo = miopenConvBwdWeightsAlgorithm_t; - -miopenStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { - return miopenConvolutionBackwardDataGetWorkSpaceSize(args.handle, args.y_tensor, args.x_tensor, args.conv_desc, - args.w_desc, workspace_size); -} - -miopenStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { - return miopenConvolutionBackwardWeightsGetWorkSpaceSize(args.handle, args.y_tensor, args.x_tensor, args.conv_desc, - args.w_desc, workspace_size); -} - -template -size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { - // Calling hipMemGetInfo is not ideal, but our rocm allocator doesn't have a way to get this info. - size_t free, total; - HIP_CALL_THROW(hipMemGetInfo(&free, &total)); - // Assuming 10% of fragmentation. - free = static_cast(static_cast(free) * 0.9); - size_t max_workspace_size = 0; - for (int i = 0; i < n_algo; i++) { - miopenStatus_t status; - size_t workspace_size; - status = GetWorkspaceSize(args, algo[i], &workspace_size); - if (miopenStatusSuccess != status || workspace_size == 0 || workspace_size < max_workspace_size || - workspace_size > free) - continue; - max_workspace_size = workspace_size; - } - - return max_workspace_size; -} - -template -std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { - std::vector result; - result.reserve(n_algo); - for (int i = 0; i < n_algo; i++) { - T_Perf perf = perf_results[i]; - result.emplace_back(perf); - } - ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in MIOpen"); - return result; -} - -struct ConvParamsHash { - // ConvParams must be a trivial type because we read out its memory contents as char* when hashing. - static_assert(std::is_trivial::value, "ConvParams is not a trivial type"); - size_t operator()(const ConvParams& conv_params) const { - auto ptr = reinterpret_cast(&conv_params); - uint32_t value = 0x811C9DC5; - for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { - value ^= ptr[i]; - value *= 0x01000193; - } - return static_cast(value); - } -}; - -struct ConvParamsEqual { - // ConvParams must be a trivial type because we read out its memory contents as char* when hashing. 
- static_assert(std::is_trivial::value, "ConvParams is not a trivial type"); - bool operator()(const ConvParams& a, const ConvParams& b) const { - auto ptr1 = reinterpret_cast(&a); - auto ptr2 = reinterpret_cast(&b); - return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; - } -}; - -template -struct AlgoPerfCache { - mutable std::mutex mutex; - std::unordered_map map; - - bool Find(const ConvParams& params, T_Perf* result) { - std::lock_guard guard(mutex); - auto it = map.find(params); - if (it == map.end()) { - return false; - } - *result = it->second; - return true; - } - - void Insert(const ConvParams& params, const T_Perf& algo_perf) { - std::lock_guard guard(mutex); - map[params] = algo_perf; - } -}; - -// TODO: Currently we use global AlgoPerfCache for ConvGrad only. Conv's perf cache is still per node. -// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. -AlgoPerfCache bwd_data_algos; -AlgoPerfCache bwd_filter_algos; - -template -struct AlgoSearch {}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM; - static AlgoPerfCache& Cache() { return bwd_data_algos; } - static Status FindAlgorithms(const ConvArgs& args, const ROCMExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdDataAlgo algos[] = { - miopenConvolutionBwdDataAlgoGEMM, - miopenConvolutionBwdDataAlgoDirect, - miopenConvolutionBwdDataAlgoFFT, - miopenConvolutionBwdDataAlgoWinograd, - miopenTransposeBwdDataAlgoGEMM, - miopenConvolutionBwdDataAlgoImplicitGEMM}; - static constexpr int num_algos = MIOPEN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing MIOpen convolution backward data algorithms."); - int perf_count; - std::unique_ptr candidates = std::make_unique(num_algos); - size_t max_workspace_size = provider->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? 
nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm( - args.handle, args.y_tensor, args.dy_data, args.w_desc, args.w_data, args.conv_desc, args.x_tensor, - args.dx_data, 1, &perf_count, candidates.get(), workspace.get(), max_workspace_size, false)); - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM; - static AlgoPerfCache& Cache() { return bwd_filter_algos; } - static Status FindAlgorithms(const ConvArgs& args, const ROCMExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdFilterAlgo algos[] = { - miopenConvolutionBwdWeightsAlgoGEMM, - miopenConvolutionBwdWeightsAlgoDirect, - miopenConvolutionBwdWeightsAlgoWinograd, - miopenConvolutionBwdWeightsAlgoImplicitGEMM}; - - static constexpr int num_algos = MIOPEN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing MIOpen convolution backward filter algorithms."); - std::unique_ptr candidates = std::make_unique(num_algos); - int perf_count; - size_t max_workspace_size = provider->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardWeightsAlgorithm( - args.handle, args.y_tensor, args.dy_data, args.x_tensor, args.x_data, args.conv_desc, args.w_desc, - args.dw_data, 1, &perf_count, candidates.get(), workspace.get(), max_workspace_size, false)); - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template -class AlgoIterator { - public: - AlgoIterator(const ConvArgs& args) : args_(args) {} - - Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); - - Status TryAll(const ROCMExecutionProvider* provider, const AllocatorPtr& allocator, std::function f) { - auto& cache = AlgoSearch::Cache(); - miopenConvAlgoPerf_t algo_perf; - if (cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { - return Status::OK(); - } - - std::vector perf_results; - ORT_RETURN_IF_ERROR(AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); - for (auto& algo_perf : perf_results) { - if (f(algo_perf) == Status::OK()) { - cache.Insert(args_.params, algo_perf); - return Status::OK(); - } - } - ORT_ENFORCE(false, "Unable to find a valid MIOpen algorithm to run convolution."); - return Status::OK(); - } - - private: - const ConvArgs& args_; -}; - -template <> -Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { - perf_results.resize(1); - perf_results[0].bwd_data_algo = AlgoSearch::DEFAULT_ALGO; - MIOPEN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].bwd_data_algo, &(perf_results[0].memory))); - return Status::OK(); -} - -template <> -Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { - perf_results.resize(1); - perf_results[0].bwd_weights_algo = AlgoSearch::DEFAULT_ALGO; - MIOPEN_RETURN_IF_ERROR(GetWorkspaceSize(args, 
perf_results[0].bwd_weights_algo, &(perf_results[0].memory))); - return Status::OK(); -} - -template -Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, - Tensor* dW, miopenHandle_t miopen_handle) const { - const TensorShape& x_shape = x.Shape(); - auto x_dims = x_shape.AsShapeVector(); - args_.x_data = reinterpret_cast(x.template Data()); - - const TensorShape& dy_shape = dY.Shape(); - auto dy_dims = dy_shape.AsShapeVector(); - args_.dy_data = reinterpret_cast(dY.template Data()); - - const TensorShape& w_shape = w.Shape(); - auto w_dims = w_shape.AsShapeVector(); - args_.w_data = reinterpret_cast(w.template Data()); - - args_.db_data = dB ? reinterpret_cast(dB->template MutableData()) : nullptr; - args_.dx_data = dX ? reinterpret_cast(dX->template MutableData()) : nullptr; - args_.dw_data = dW ? reinterpret_cast(dW->template MutableData()) : nullptr; - - bool x_dims_changed = (args_.last_x_dims != x_dims); - bool w_dims_changed = (args_.last_w_dims != w_dims); - if (x_dims_changed || w_dims_changed) { - if (x_dims_changed) args_.last_x_dims = x_dims; - if (w_dims_changed) args_.last_w_dims = w_dims; - - // Update Attributes - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&x, &w)); - - TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); - auto rank = kernel_shape.size(); - - ConvAttributes::ConvPadVector pads(conv_attrs_.pads); - if (pads.empty()) { - pads.resize(rank * 2, 0); - } - - TensorShapeVector dilations(conv_attrs_.dilations); - if (dilations.empty()) { - dilations.resize(rank, 1); - } - - TensorShapeVector strides(conv_attrs_.strides); - if (strides.empty()) { - strides.resize(rank, 1); - } - - // MIOpen only takes 4D or 5D x tensor, so pad dimensions if needed. 
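// For example, a 1-D convolution with x of shape (N, C, L) is presented to MIOpen as
// (N, C, L, 1): a trailing spatial dimension of size 1 is appended to x, dY and W, the kernel
// shape, strides and dilations get an extra entry of 1, and the pads vector gets an extra 0 for
// each side of the new axis, which is exactly what the branch below does.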
- if (rank < 2) { - x_dims.push_back(1); - dy_dims.push_back(1); - w_dims.push_back(1); - pads.insert(pads.begin() + rank, 0); - pads.insert(pads.end(), 0); - kernel_shape.push_back(1); - strides.push_back(1); - dilations.push_back(1); - } - - const ROCMExecutionProvider* rocm_ep = - static_cast(this->Info().GetExecutionProvider()); - memset(&args_.params, 0, sizeof(ConvParams)); - args_.params.device_id = static_cast(rocm_ep->GetDeviceId()); - args_.params.data_type = MiopenTensor::GetDataType(); - args_.params.input_dim = static_cast(x_dims.size()); - for (size_t i = 0; i < x_dims.size(); i++) { - args_.params.input_size[i] = static_cast(x_dims[i]); - args_.params.weight_size[i] = static_cast(w_dims[i]); - } - for (size_t i = 0; i < rank; i++) { - args_.params.padding[i] = static_cast(pads[i]); - args_.params.padding[i + rank] = static_cast(pads[i + rank]); - args_.params.stride[i] = static_cast(strides[i]); - args_.params.dilation[i] = static_cast(dilations[i]); - } - args_.params.groups = conv_attrs_.group; - args_.handle = miopen_handle; - ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.params.data_type)); - ORT_RETURN_IF_ERROR(args_.x_tensor.Set(x_dims, args_.params.data_type)); - ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type)); - ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, - gsl::narrow_cast(conv_attrs_.group), miopenConvolution, - args_.params.data_type)); - - if (dB) { - const TensorShape& db_shape = dB->Shape(); - ORT_RETURN_IF_NOT(db_shape.NumDimensions() == 1, "bias should be 1D"); - TensorShapeVector db_dims(2 + kernel_shape.size(), 1); - db_dims[1] = db_shape[0]; - ORT_RETURN_IF_ERROR(args_.b_tensor.Set(db_dims, MiopenTensor::GetDataType())); - } - } - - return Status::OK(); -} - -template -Status ConvGrad::ComputeInternal(OpKernelContext* context) const { - const Tensor* dY = context->Input(0); - const Tensor* X = context->Input(1); - const Tensor* W = context->Input(2); - Tensor* dX = context->Output(0, X->Shape()); - Tensor* dW = context->Output(1, W->Shape()); - Tensor* dB = context->Output(2, {W->Shape()[0]}); - ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB, dX, dW, GetMiopenHandle(context))); - if (dX) ORT_RETURN_IF_ERROR(ComputeInputGradient(context->GetComputeStream())); - if (dW) ORT_RETURN_IF_ERROR(ComputeWeightGradient(context->GetComputeStream())); - if (dB) ORT_RETURN_IF_ERROR(ComputeBiasGradient()); - return Status::OK(); -} - -template -Status ConvGrad::ComputeInputGradient(onnxruntime::Stream* stream) const { - return AlgoIterator(args_).TryAll( - static_cast(Info().GetExecutionProvider()), - Info().GetAllocator(OrtMemType::OrtMemTypeDefault), - [&](const T_BwdDataPerf& algo_perf) -> Status { - const auto one = Consts::One; - const auto zero = Consts::Zero; - IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory, stream); - MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardData( - args_.handle, &one, args_.y_tensor, args_.dy_data, args_.w_desc, args_.w_data, args_.conv_desc, - algo_perf.bwd_data_algo, &zero, args_.x_tensor, args_.dx_data, workspace.get(), algo_perf.memory)); - return Status::OK(); - }); -} - -template -Status ConvGrad::ComputeWeightGradient(onnxruntime::Stream* stream) const { - return AlgoIterator(args_).TryAll( - static_cast(Info().GetExecutionProvider()), - Info().GetAllocator(OrtMemType::OrtMemTypeDefault), - [&](const T_BwdFilterPerf& algo_perf) -> Status { - const auto one = Consts::One; - const auto zero = Consts::Zero; - IAllocatorUniquePtr workspace = 
GetScratchBuffer(algo_perf.memory, stream); - MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardWeights( - args_.handle, &one, args_.y_tensor, args_.dy_data, args_.x_tensor, args_.x_data, args_.conv_desc, - algo_perf.bwd_weights_algo, &zero, args_.w_desc, args_.dw_data, workspace.get(), algo_perf.memory)); - return Status::OK(); - }); -} - -template -Status ConvGrad::ComputeBiasGradient() const { - const auto one = Consts::One; - const auto zero = Consts::Zero; - MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardBias( - args_.handle, &one, args_.y_tensor, args_.dy_data, &zero, - args_.b_tensor, args_.db_data)); - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.h b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.h deleted file mode 100644 index d1f84c259a66a..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/cpu/nn/conv_attributes.h" -#include "core/providers/rocm/nn/conv.h" - -namespace onnxruntime { -namespace rocm { - -constexpr int MAX_DIM = 3; - -struct ConvParams { - int8_t device_id; - miopenDataType_t data_type; - int input_size[2 + MAX_DIM]; - uint8_t input_dim; - int weight_size[2 + MAX_DIM]; - int padding[MAX_DIM * 2]; - int stride[MAX_DIM]; - int dilation[MAX_DIM]; - int64_t groups; -}; - -struct ConvArgs { - // Update needed if x or w's dims changed. - TensorShapeVector last_x_dims; - TensorShapeVector last_w_dims; - - miopenHandle_t handle; - ConvParams params; - MiopenTensor x_tensor, y_tensor, b_tensor; - MiopenTensorDescriptor w_desc; - MiopenConvolutionDescriptor conv_desc; - const void* x_data; - const void* w_data; - const void* dy_data; - void* dx_data; - void* dw_data; - void* db_data; -}; - -template -class ConvGrad final : public RocmKernel { - public: - using HipT = typename ToHipType::MappedType; - - ConvGrad(const OpKernelInfo& info) : RocmKernel(info), conv_attrs_(info) { - auto pads_size = conv_attrs_.pads.size(); - ORT_ENFORCE(pads_size % 2 == 0); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - Status PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, Tensor* dW, miopenHandle_t miopen_handle) const; - mutable ConvArgs args_; - ConvAttributes conv_attrs_; - - private: - Status ComputeWeightGradient(onnxruntime::Stream* stream) const; - Status ComputeInputGradient(onnxruntime::Stream* stream) const; - Status ComputeBiasGradient() const; -}; - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/reduction/reduction_all.cc b/orttraining/orttraining/training_ops/rocm/reduction/reduction_all.cc deleted file mode 100644 index 093a516ce8241..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/reduction/reduction_all.cc +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "orttraining/training_ops/rocm/reduction/reduction_all.h" -#include "orttraining/training_ops/rocm/reduction/reduction_all_impl.h" - -#include "core/providers/rocm/reduction/reduction_functions.h" -#include "core/providers/rocm/shared_inc/accumulation_type.h" - -namespace onnxruntime { -namespace rocm { - -#define REGISTER_REDUCE_ALL_KERNEL_TYPED(Name, TIn, TOut) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Name, \ - kMSDomain, \ - 1, \ - TIn##_##TOut, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("TIn", DataTypeImpl::GetTensorType()).TypeConstraint("TOut", DataTypeImpl::GetTensorType()), \ - Name); - -template -Status ReduceAllL2::ComputeInternal(OpKernelContext* ctx) const { - typedef typename ToHipType::MappedType HipTIn; - typedef typename ToHipType::MappedType HipTOut; - - // Get Input tensor count. - const auto total_tensor_count = ctx->InputCount(); - // We only have one tensor per group so - // grouped_tensor_pointers[i] always contains only one element. - std::vector> grouped_tensor_pointers(total_tensor_count); - std::vector tensor_sizes(total_tensor_count); - - for (int i = 0; i < total_tensor_count; ++i) { - const Tensor* input = ctx->Input(i); - const auto size = input->Shape().Size(); - ORT_ENFORCE(size <= std::numeric_limits::max(), "Number of reduced elements (", - size, ") exceeds the max allowed value (", std::numeric_limits::max(), ")."); - grouped_tensor_pointers[i] = {const_cast(input->Data())}; - tensor_sizes[i] = static_cast(size); - } - - // Allocate output tensor. - Tensor* output = ctx->Output(0, {}); - HipTOut* p_output = reinterpret_cast(output->template MutableData()); - HIP_RETURN_IF_ERROR(hipMemsetAsync(p_output, 0, sizeof(HipTOut), Stream(ctx))); - - // const bool deterministic = ctx->GetUseDeterministicCompute(); - bool deterministic = true; - - if (!deterministic) { - typedef MultiTensorReduceL2 TFunctor; - TFunctor functor; - - // Check if all values are finite and write true to deviceOutput. - // Otherwise, false will be written. - launch_multi_tensor_functor<1, TFunctor>(Stream(ctx), - 2048 * 32, tensor_sizes, grouped_tensor_pointers, functor, p_output); - - // *p_output is the squared sum of all elements. - // Let's take a sqrt to get the actual L2-norm. - ScalarSqrt(Stream(ctx), p_output, p_output); - } else { - // alternate path only for deterministic compute .. 
- typedef AccumulationType_t HipTAcc; - - // find reduction buffer size needed by 'reduce_square_sum' for each tensor - size_t reduction_buffer_size = 0; - for (int i = 0; i < total_tensor_count; ++i) { - reduction_buffer_size = - std::max(reduction_buffer_size, compute_reduction_buffer_size(tensor_sizes[i])); - } - - // enlarge reduction buffer size for 'reduce_sum' over tensor square norms - reduction_buffer_size = - std::max(reduction_buffer_size, compute_reduction_buffer_size(total_tensor_count)); - - // create GPU scratch space and zero target for each tensor square norm - auto reduction_buffer = GetScratchBuffer(reduction_buffer_size, ctx->GetComputeStream()); - - // buffer for final output and square norms of each tensor - auto results_buffer = GetScratchBuffer(1 + total_tensor_count, ctx->GetComputeStream()); - - HIP_RETURN_IF_ERROR(hipMemsetAsync(results_buffer.get(), 0, sizeof(HipTAcc) * (1 + total_tensor_count), Stream(ctx))); - - HipTAcc* p_global_sqnorm = results_buffer.get(); - HipTAcc* p_tensor_sqnorm = p_global_sqnorm + 1; - - // perform reduction l2norm = sqrt[sum(tensor[i][j]**2)] for i,j over all tensor elements - for (int i = 0; i < total_tensor_count; ++i) { - HipTIn* p_tensor_i = reinterpret_cast(grouped_tensor_pointers[i][0]); - ORT_RETURN_IF_ERROR(reduce_square_sum( - Stream(ctx), p_tensor_i, p_tensor_sqnorm + i, tensor_sizes[i], reduction_buffer.get(), reduction_buffer_size)); - } - ORT_RETURN_IF_ERROR(reduce_sum( - Stream(ctx), p_tensor_sqnorm, p_global_sqnorm, total_tensor_count, reduction_buffer.get(), reduction_buffer_size)); - ScalarSqrt(Stream(ctx), p_global_sqnorm, p_output); - } - - return Status::OK(); -} - -REGISTER_REDUCE_ALL_KERNEL_TYPED(ReduceAllL2, float, float) -REGISTER_REDUCE_ALL_KERNEL_TYPED(ReduceAllL2, MLFloat16, float) -REGISTER_REDUCE_ALL_KERNEL_TYPED(ReduceAllL2, float, MLFloat16) -REGISTER_REDUCE_ALL_KERNEL_TYPED(ReduceAllL2, MLFloat16, MLFloat16) -REGISTER_REDUCE_ALL_KERNEL_TYPED(ReduceAllL2, BFloat16, float) -REGISTER_REDUCE_ALL_KERNEL_TYPED(ReduceAllL2, float, BFloat16) -REGISTER_REDUCE_ALL_KERNEL_TYPED(ReduceAllL2, BFloat16, BFloat16) - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/reduction/reduction_ops.cc b/orttraining/orttraining/training_ops/rocm/reduction/reduction_ops.cc deleted file mode 100644 index 23811744885e0..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/reduction/reduction_ops.cc +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "orttraining/training_ops/rocm/reduction/reduction_ops.h" -#include "core/providers/common.h" -#include "core/providers/rocm/miopen_common.h" -#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" -#include "core/providers/rocm/math/binary_elementwise_ops_impl.h" -#include "core/providers/rocm/math/binary_elementwise_ops.h" -#include "core/providers/cpu/tensor/utils.h" - -using namespace onnxruntime::common; -namespace onnxruntime { -namespace rocm { - -#define REGISTER_MS_KERNEL_TYPED(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, MLFloat16) -REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, float) -// REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, double) -REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, int32_t) - -template -template -Status ReduceKernel::ComputeImplEx(OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const { - const Tensor* X = ctx->Input(0); - - // override the attribute value with the input value for reduction_axes - const Tensor* axes_tensor = ctx->Input(1); - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->template Data(); - std::vector axes(data, data + nDims); - - // empty axes and no-op - if (axes.empty() && noop_with_empty_axes_) { - auto* Y = ctx->Output(0, X->Shape()); - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData(), X->template Data(), X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream(ctx))); - return Status::OK(); - } - - PrepareReduceMetadata prepare_reduce_metadata; - ORT_RETURN_IF_ERROR(PrepareForReduce(X, - keepdims_, - axes, - prepare_reduce_metadata)); - Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); - const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, miopen_reduce_op, axes, - calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream()); -} - -template <> -template <> -Status ReduceKernel::ComputeImplEx(OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const { - typedef typename ToHipType::MappedType HipT; - - const Tensor* X = ctx->Input(0); - - // override the attribute value with the input value for reduction_axes - const Tensor* axes_tensor = ctx->Input(1); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->template Data(); - std::vector axes(data, data + nDims); - - // empty axes and no-op - if (axes.empty() && noop_with_empty_axes_) { - auto* Y = ctx->Output(0, X->Shape()); - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData(), X->template Data(), X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream(ctx))); - return Status::OK(); - } - - PrepareReduceMetadata prepare_reduce_metadata; - - ORT_RETURN_IF_ERROR(PrepareForReduce(X, - keepdims_, - axes, - prepare_reduce_metadata)); - - Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); - - int64_t input_count = 
prepare_reduce_metadata.input_count; - int64_t output_count = prepare_reduce_metadata.output_count; - auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; - auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; - - // special case when there is a dim value of 0 in the shape. - if (input_count == 0) { - assert(Y->Shape().Size() == 0); - return Status::OK(); - } - - // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case - if (input_count == output_count) { - if (Y->template MutableData() != X->template Data()) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData(), X->template Data(), input_count * sizeof(int32_t), hipMemcpyDeviceToDevice, Stream(ctx))); - } - return Status::OK(); - } - - // This reduction keep adding values to this buffer. If a non-zero value, say 1000, is here, the sum will start with 1000. - // Therefore zeroing out the memory is required - HIP_RETURN_IF_ERROR(hipMemsetAsync(Y->MutableDataRaw(), 0, Y->SizeInBytes(), Stream(ctx))); - - size_t indices_bytes = 0; - size_t workspace_bytes = 0; - MiopenTensor input_tensor; - MiopenTensor output_tensor; - MiopenReduceDescriptor reduce_desc; - - miopenDataType_t miopen_type_X = miopenFloat; - IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, ctx->GetComputeStream()); - Impl_Cast(Stream(ctx), reinterpret_cast(X->template Data()), temp_X.get(), X->Shape().Size()); - - ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES)); - ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); - ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X)); - MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(GetMiopenHandle(ctx), reduce_desc, input_tensor, output_tensor, &indices_bytes)); - MIOPEN_RETURN_IF_ERROR(miopenGetReductionWorkspaceSize(GetMiopenHandle(ctx), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); - IAllocatorUniquePtr indices_rocm = GetScratchBuffer(indices_bytes, ctx->GetComputeStream()); - IAllocatorUniquePtr workspace_rocm = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - - const auto one = Consts::One; - const auto zero = Consts::Zero; - auto temp_Y = GetScratchBuffer(output_count, ctx->GetComputeStream()); - MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(GetMiopenHandle(ctx), - reduce_desc, - indices_rocm.get(), - indices_bytes, - workspace_rocm.get(), - workspace_bytes, - &one, - input_tensor, - temp_X.get(), - &zero, - output_tensor, - temp_Y.get())); - - Impl_Cast(Stream(ctx), temp_Y.get(), Y->template MutableData(), output_count); - - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc deleted file mode 100644 index c570f727f2a92..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ /dev/null @@ -1,437 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
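For reference, the int32_t specialization of ReduceSumTraining removed above runs the reduction in float: the input is cast into a float scratch buffer, reduced with miopenReduceTensor, and the float result is cast back to int32_t. A minimal CPU sketch of that round trip, assuming the sum stays within float's exactly representable integer range; SumViaFloat is a hypothetical name, not part of the deleted code:

#include <cstdint>
#include <vector>

int32_t SumViaFloat(const std::vector<int32_t>& values) {
  float acc = 0.0f;  // float accumulation, mirroring the miopenFloat reduction on device
  for (int32_t v : values) acc += static_cast<float>(v);
  return static_cast<int32_t>(acc);  // cast back to int32, as the removed kernel does on device
}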
- -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_fwd.h" -#include "core/providers/rocm/rocm_pch.h" - -using namespace onnxruntime::common; - -namespace onnxruntime { -namespace rocm { - -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, View); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Group); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PassThrough); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SGDOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ReduceSumTraining); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ReduceSumTraining); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SplitTraining); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ConcatTraining); - -// Adam -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float_float_float_float_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_float_MLFloat16_float_float_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float_MLFloat16_float_float_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float_float_MLFloat16_MLFloat16_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float_float_MLFloat16_float_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_float_MLFloat16_MLFloat16_MLFloat16_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_float_MLFloat16_MLFloat16_float_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float_MLFloat16_MLFloat16_MLFloat16_MLFloat16, AdamOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float_MLFloat16_MLFloat16_float_MLFloat16, AdamOptimizer); -// Lamb -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float_float_float_MLFloat16, LambOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16_float_MLFloat16_MLFloat16, LambOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16_float_float_MLFloat16, LambOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double_double_double_MLFloat16, LambOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_MLFloat16_MLFloat16, LambOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_float_MLFloat16, LambOptimizer); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_float_MLFloat16_MLFloat16, LambOptimizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_float_float_MLFloat16, LambOptimizer); -// Gradient accumulator -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float, InPlaceAccumulator); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_MLFloat16, InPlaceAccumulator); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, InPlaceAccumulator); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float, InPlaceAccumulator); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ZeroGradient); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ZeroGradient); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropy); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropyGrad); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropy); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropy); -// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropyGrad); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, int64_t, SoftmaxCrossEntropyLoss); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, int64_t, SoftmaxCrossEntropyLoss); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, int64_t, SoftmaxCrossEntropyLoss); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, int64_t, SoftmaxCrossEntropyLoss); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, int64_t, SoftmaxCrossEntropyLossGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_float, SoftmaxCrossEntropyLossInternal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float, SoftmaxCrossEntropyLossInternal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_MLFloat16, SoftmaxCrossEntropyLossInternal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_int64_t_BFloat16, SoftmaxCrossEntropyLossInternal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, 
float_int64_t_MLFloat16, SoftmaxCrossEntropyLossInternalGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int64_t_float, SoftmaxCrossEntropyLossInternalGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_MLFloat16, SoftmaxCrossEntropyLossInternalGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_int64_t_BFloat16, SoftmaxCrossEntropyLossInternalGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SoftmaxGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, LogSoftmaxGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SoftmaxGrad_13);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, LogSoftmaxGrad_13);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvGrad);
-// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ConvGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DropoutGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropoutGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmaxDropout);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SoftmaxDropoutGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int64_t, GatherNDGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DivGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, DivGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DivGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GeluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GeluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GeluGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGeluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGeluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGeluGrad);
-
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasFastGeluGrad_dX);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ReluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ReluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ReluGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SigmoidGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SigmoidGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LeakyReluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LeakyReluGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, IsFinite);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, IsFinite);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, bool, All);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, IsAllFinite);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, IsAllFinite);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, IsAllFinite);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MixedPrecisionScale);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MixedPrecisionScale);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float, ReduceAllL2);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float, ReduceAllL2);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_MLFloat16, ReduceAllL2);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, ReduceAllL2);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, LayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, LayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, LayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16, LayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, LayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, InvertibleLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, InvertibleLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, InvertibleLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16, InvertibleLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, InvertibleLayerNormalizationGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SliceGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherElementsGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Scale);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Scale);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Scale);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Scale);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistBinarizeDecoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeDecoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GistBinarizeDecoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, bool, GistPack1Encoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPack1Encoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, bool, GistPack1Decoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPack1Decoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPack8Encoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Encoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPack8Decoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Decoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPack16Encoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPack16Decoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Encoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Decoder);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormInternal);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormInternal);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad);
-
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad);
-
-#if defined(ORT_USE_NCCL) || defined(USE_MPI)
-// P2P communication operators.
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Send); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Recv); -#endif - -#ifdef USE_MPI -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AdasumAllReduce); -#endif - -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, RecordEvent); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, WaitEvent); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, YieldOp); - -#ifdef ENABLE_TRAINING_TORCH_INTEROP -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PythonOp); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PythonOpGrad); -#endif - -#ifdef ORT_USE_NCCL -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NcclAllReduce); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NcclAllGather); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NcclReduceScatter); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MegatronF); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MegatronG); -#endif - -Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // Adam - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Lamb - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - -// P2P communication operators. -#if defined(ORT_USE_NCCL) || defined(USE_MPI) - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif - -#ifdef USE_MPI - // BuildKernelCreateInfo, -#endif - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - -#ifdef ENABLE_TRAINING_TORCH_INTEROP - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif - -#ifdef ORT_USE_NCCL - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif - }; - - for (auto& function_table_entry : function_table) { - ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); - } - - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.h b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.h deleted file mode 100644 index 697975b7f3409..0000000000000 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.h +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
-
-#pragma once
-
-namespace onnxruntime {
-namespace rocm {
-
-Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry);
-
-} // namespace rocm
-} // namespace onnxruntime
diff --git a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch b/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch
deleted file mode 100644
index 29b8812c979e4..0000000000000
--- a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch
+++ /dev/null
@@ -1,170 +0,0 @@
-# docker build --network=host --file Dockerfile.rocm4.3.1.pytorch --tag ort:rocm4.3.1-pytorch .
-
-FROM rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0
-
-RUN apt-get -y install gpg-agent
-RUN wget -q -O - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
-RUN echo 'deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.3.1 xenial main' | tee /etc/apt/sources.list.d/rocm.list
-
-RUN apt-get -y update
-RUN apt-get -y install apt-utils
-RUN apt-get -y install build-essential autotools-dev \
-  make git curl vim wget rsync jq openssh-server openssh-client sudo \
-  iputils-ping net-tools ethtool libcap2 \
-  automake autoconf libtool flex doxygen \
-  perl lsb-release iproute2 pciutils graphviz \
-  bc tar git bash pbzip2 pv bzip2 unzip cabextract \
-  g++ gcc \
-  && apt-get autoremove
-
-# sh
-RUN rm /bin/sh && ln -s /bin/bash /bin/sh
-
-# Labels for the docker
-LABEL description="This docker sets up the environment to run ORT Training with AMD GPU"
-
-# CMake
-ENV CMAKE_VERSION=3.18.2
-RUN cd /usr/local && \
-  wget -q -O - https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-Linux-x86_64.tar.gz | tar zxf -
-ENV PATH=/usr/local/cmake-${CMAKE_VERSION}-Linux-x86_64/bin:${PATH}
-
-ENV WORKSPACE_DIR=/workspace
-RUN mkdir -p $WORKSPACE_DIR
-WORKDIR $WORKSPACE_DIR
-
-ENV OLD_PATH=${PATH}
-ENV PATH=/usr/bin:${PATH}
-# Infiniband setup, openmpi installed under /usr/mpi/gcc/openmpi-4.0.4rc3 doesn't support multi-thread
-ENV MOFED_VERSION=5.1-0.6.6.0
-ENV MOFED_OS=ubuntu18.04
-ENV MOFED_FILENAME=MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64
-RUN curl -fSsL https://www.mellanox.com/downloads/ofed/MLNX_OFED-${MOFED_VERSION}/${MOFED_FILENAME}.tgz | tar -zxpf -
-RUN cd MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64 && \
-  ./mlnxofedinstall --force --user-space-only --without-fw-update --hpc && \
-  cd .. && \
-  rm -r MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64
-
-ENV PATH=${OLD_PATH}
-ENV unset=OLD_PATH
-
-# python env
-RUN pip3 install --upgrade setuptools
-ARG NUMPY_VERSION=1.18.5
-ARG ONNX_VERSION=1.10.2
-RUN pip3 install --no-cache-dir wheel tqdm boto3 requests six ipdb h5py html2text nltk progressbar pyyaml \
-  git+https://github.com/NVIDIA/dllogger \
-  numpy==${NUMPY_VERSION} \
-  onnx=="${ONNX_VERSION}"
-
-ENV GITHUB_DIR=$WORKSPACE_DIR/github
-RUN mkdir -p $GITHUB_DIR
-
-# UCX
-WORKDIR $GITHUB_DIR
-RUN apt-get -y update && apt-get -y --no-install-recommends install libnuma-dev
-ARG UCX_VERSION=1.9.0-rc3
-ENV UCX_DIR=$WORKSPACE_DIR/ucx-$UCX_VERSION
-RUN git clone https://github.com/openucx/ucx.git \
-  && cd ucx \
-  && git checkout v$UCX_VERSION \
-  && ./autogen.sh \
-  && mkdir build \
-  && cd build \
-  && ../contrib/configure-opt --prefix=$UCX_DIR --without-rocm --without-knem --without-cuda \
-  && make -j"$(nproc)" \
-  && make install \
-  && cd .. \
-  && rm -rf build
-
-# OpenMPI
-# note: require --enable-orterun-prefix-by-default for Azure machine learning compute
-# note: disable verbs as we use ucx middleware and don't want btl openib warnings
-WORKDIR $GITHUB_DIR
-ARG OPENMPI_BASEVERSION=4.0
-ARG OPENMPI_VERSION=${OPENMPI_BASEVERSION}.5
-ENV OPENMPI_DIR=$WORKSPACE_DIR/openmpi-${OPENMPI_VERSION}
-RUN git clone --recursive https://github.com/open-mpi/ompi.git \
-  && cd ompi \
-  && git checkout v$OPENMPI_VERSION \
-  && ./autogen.pl \
-  && mkdir build \
-  && cd build \
-  && ../configure --prefix=$OPENMPI_DIR --with-ucx=$UCX_DIR --without-verbs \
-  --enable-mpirun-prefix-by-default --enable-orterun-prefix-by-default \
-  --enable-mca-no-build=btl-uct --disable-mpi-fortran \
-  && make -j"$(nproc)" \
-  && make install \
-  && cd .. \
-  && rm -rf build \
-  && ldconfig \
-  && test -f ${OPENMPI_DIR}/bin/mpic++
-
-ENV PATH=$OPENMPI_DIR/bin:${PATH}
-ENV LD_LIBRARY_PATH=$OPENMPI_DIR/lib:${LD_LIBRARY_PATH}
-
-# Create a wrapper for OpenMPI to allow running as root by default
-RUN mv $OPENMPI_DIR/bin/mpirun $OPENMPI_DIR/bin/mpirun.real && \
-  echo '#!/bin/bash' > $OPENMPI_DIR/bin/mpirun && \
-  echo 'mpirun.real --allow-run-as-root "$@"' >> $OPENMPI_DIR/bin/mpirun && \
-  chmod a+x $OPENMPI_DIR/bin/mpirun
-
-# install mpi4py (be sure to link existing /opt/openmpi-xxx)
-RUN CC=mpicc MPICC=mpicc pip install mpi4py --no-binary mpi4py
-
-ARG CACHE_DATA=2021-10-25
-
-# ONNX Runtime
-WORKDIR $GITHUB_DIR
-ENV ORT_DIR=$GITHUB_DIR/onnxruntime
-RUN git clone -b wezhan/tnlrv4 --recursive https://github.com/microsoft/onnxruntime.git \
-  && cd onnxruntime \
-  && python3 tools/ci_build/build.py \
-  --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` \
-  --build_dir build \
-  --config Release \
-  --parallel \
-  --skip_tests \
-  --build_wheel \
-  --use_rocm --rocm_version=4.3.1 --rocm_home /opt/rocm \
-  --mpi_home $OPENMPI_DIR \
-  --nccl_home /opt/rocm \
-  --enable_training \
-  && test -f $ORT_DIR/build/Release/onnxruntime_training_bert \
-  && pip install $ORT_DIR/build/Release/dist/*.whl \
-  && ldconfig
-
-RUN pip3 install --no-cache-dir GPUtil azureml azureml-core datasets tokenizers ninja cerberus sympy sacremoses sacrebleu
-
-RUN pip install transformers==2.10.0 scikit-learn tensorboardX
-RUN pip install --pre torch-ort -f https://download.onnxruntime.ai/torch_ort_nightly.html
-RUN python -m torch_ort.configure
-
-# Enable ssh access without password needed
-RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/g' /etc/ssh/sshd_config
-RUN sed -i 's/#StrictModes yes/StrictModes no/g' /etc/ssh/sshd_config
-RUN sed -i 's/#PubkeyAuthentication yes/PubkeyAuthentication yes/g' /etc/ssh/sshd_config
-RUN sed -i 's/#PermitEmptyPasswords no/PermitEmptyPasswords yes/g' /etc/ssh/sshd_config
-
-# Start or Restart sshd service
-ENTRYPOINT service ssh restart && /bin/bash
-
-# Add model and scripts
-ADD script ${WORKSPACE_DIR}/script
-RUN chmod a+x ${WORKSPACE_DIR}/script/run_bert.sh
-
-# add locale en_US.UTF-8
-RUN apt-get install -y locales
-RUN locale-gen en_US.UTF-8
-
-# Workaround an issue in AMD compiler which generates poor GPU ISA
-# when the type of kernel parameter is a structure and “pass-by-value” is used
-# ENV HSA_NO_SCRATCH_RECLAIM=1
-
-# Distributed training related environment variables
-ENV HSA_FORCE_FINE_GRAIN_PCIE=1
-# ENV NCCL_DEBUG=INFO
-# ENV RCCL_ALLTOALL_KERNEL_DISABLE=1
-# ENV NCCL_DEBUG_SUBSYS=INIT,COLL
-
-WORKDIR ${WORKSPACE_DIR}/script
diff --git a/orttraining/tools/amdgpu/script/rocprof.py b/orttraining/tools/amdgpu/script/rocprof.py
deleted file mode 100644
index 21dd8501f3f1d..0000000000000
--- a/orttraining/tools/amdgpu/script/rocprof.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import argparse
-import csv
-import os  # noqa: F401
-
-import numpy as np  # noqa: F401
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--input", type=str)
-args = parser.parse_args()
-
-
-def get_gpu_lines(path):
-    lines = []
-    with open(path, newline="") as f:
-        reader = csv.reader(f, delimiter=",")
-        for row in reader:
-            if row[2].find("TotalDurationNs") < 0:
-                lines.append(row)
-    return lines
-
-
-activities = [
-    ("nccl", lambda x: x.find("nccl") >= 0),
-    ("gemm", lambda x: x.find("Cijk_") >= 0),
-    ("memcpy", lambda x: x.find("CUDA mem") >= 0),
-    ("adam", lambda x: x.lower().find("adam") >= 0),
-    ("lamb", lambda x: x.lower().find("lamb") >= 0 or x.lower().find("multi_tensor_apply") >= 0),
-    ("dropout", lambda x: x.lower().find("dropout") >= 0 or x.find("curand") >= 0),
-    ("layernorm", lambda x: x.find("LayerNorm") >= 0 or x.find("cuCompute") >= 0),
-    ("reduce", lambda x: x.find("reduce") >= 0),
-    ("softmax", lambda x: x.lower().find("softmax") >= 0),
-    ("transpose", lambda x: x.lower().find("transpose") >= 0),
-    ("element-wise", lambda x: x.lower().find("elementwise") >= 0 or x.find("DivGrad") >= 0),
-    ("jit", lambda x: x.startswith("kernel_")),
-    ("misc", lambda x: True),
-]
-
-
-def group_gpu_activity(lines):
-    groups = {name: [] for name, _ in activities}
-    for line in lines:
-        for name, check in activities:
-            if check(line[0]):
-                groups[name].append(line)
-                break
-    return groups
-
-
-def get_seconds(time):
-    return float(time.replace("us", "")) / (1000.0 * 1000.0 * 1000.0)
-
-
-def gpu_percent_time(activities):
-    return sum([float(a[4].replace("%", "")) for a in activities])
-
-
-def gpu_absolute_time(activities):
-    return sum([get_seconds(a[2]) for a in activities])
-
-
-def gpu_kernel_calls(activities):
-    return sum([int(a[1]) for a in activities])
-
-
-lines = get_gpu_lines(args.input)
-groups = group_gpu_activity(lines)
-
-for name in groups:
-    activities = groups[name]
-    print(
-        f"{name}: N={len(activities)}, calls={gpu_kernel_calls(activities)}, absolute={gpu_absolute_time(activities):.3f}s, percent={gpu_percent_time(activities):.2f}%"
-    )
-
-total = [item for name in groups for item in groups[name]]
-print(
-    f"Total: N={len(total)}, calls={gpu_kernel_calls(total)}, absolute={gpu_absolute_time(total):.3f}s, percent={gpu_percent_time(total):.2f}%"
-)
diff --git a/orttraining/tools/amdgpu/script/rpl_rc.xml b/orttraining/tools/amdgpu/script/rpl_rc.xml
deleted file mode 100644
index 3ca51072b6c98..0000000000000
--- a/orttraining/tools/amdgpu/script/rpl_rc.xml
+++ /dev/null
@@ -1,10 +0,0 @@
-
-
diff --git a/orttraining/tools/amdgpu/script/run_bert.sh b/orttraining/tools/amdgpu/script/run_bert.sh
deleted file mode 100644
index 950dcaf89ff61..0000000000000
--- a/orttraining/tools/amdgpu/script/run_bert.sh
+++ /dev/null
@@ -1,84 +0,0 @@
-if [ "$#" -ne 12 ]; then
-  echo "Usage: $0 ngpu batch_size seq_len num_train_steps optimizer model_size training_mode[fp32|fp16] display_loss_steps gradient_accumulation_steps loss_scale gpu_name profile"
-  exit 1
-fi
-
-ngpu=${1:-1}
-batch_size=${2:-64}
-seq_len=${3:-128}
-
-if [ ${seq_len} == 128 ]; then
-  max_predictions_per_seq=20
-elif [ ${seq_len} == 512 ]; then
-  max_predictions_per_seq=80
-else
-  echo "seq_len is not 128 or 512"
-  exit 1
-fi
-
-num_train_steps=${4:-400}
-optimizer=${5:-"adam"}
-model_size=${6:-"large"}
-training_mode=${7:-"fp32"}
-display_loss_steps=${8:-1}
-grad_acc=${9:-1}
-loss_scale=${10:-1024}
-gpu_name=${11:-"mi100"}
-profile=${12:-0}
-
-lr=5e-5
-warmup_ratio=0.2843
-warmup_mode=Poly
-effective_batch_size=$((ngpu * batch_size * grad_acc))
-time_now=$(date +%m%d%H%M)
-
-HOME_DIR=/workspace
-ORT_DIR=${HOME_DIR}/github/onnxruntime
-commit=$(git -C ${ORT_DIR} rev-parse HEAD | cut -c1-8)
-
-if [ ${model_size} == "large" ]; then
-  model_dir=${HOME_DIR}/model/bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12
-elif [ ${model_size} == "base" ]; then
-  model_dir=${HOME_DIR}/model/bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12
-elif [ ${model_size} == "tiny" ]; then
-  model_dir=${HOME_DIR}/model/bert-tiny-uncased_L_3_H_128_A_2_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12
-else
-  echo "model_size is not large, base or tiny"
-  exit 1
-fi
-
-data_dir=/data/wezhan/bert/${seq_len}/train
-training_bert_dir=${ORT_DIR}/build/RelWithDebInfo
-
-log_dir=${HOME_DIR}/logs/bert_${model_size}/$(date +%m%d)
-if [ ! -d ${log_dir} ]; then
-  mkdir -p ${log_dir}
-fi
-
-run_name=bert_${model_size}_${commit}_g${ngpu}_bs${batch_size}_sl${seq_len}_steps${num_train_steps}_${optimizer}_${training_mode}_acc${grad_acc}_efbs${effective_batch_size}_${time_now}_${gpu_name}
-
-if [ ! -d ${log_dir}/${run_name} ]; then
-  mkdir -p ${log_dir}/${run_name}
-fi
-
-if [ ${ngpu} != 1 ]; then
-  mpi_cmd="${OPENMPI_DIR}/bin/mpirun --allow-run-as-root -n ${ngpu} -x NCCL_DEBUG=INFO -x NCCL_DEBUG_SUBSYS=INIT,COLL -x NCCL_MIN_NCHANNELS=4"
-fi
-
-if [ ${training_mode} == "fp16" ]; then
-  fp16_commands="--use_mixed_precision --allreduce_in_fp16 --loss_scale ${loss_scale}"
-fi
-
-if [ ${profile} == 1 ]; then
-  if [ ${gpu_name} == "mi100" ]; then
-    profile_commands="/opt/rocm/bin/rocprof --obj-tracking on --stats"
-  elif [ ${gpu_name} == "v100" ]; then
-    profile_commands="nvprof --print-gpu-summary --log-file ${log_dir}/${run_name}-trace.log"
-  fi
-fi
-
-nohup ${profile_commands} ${mpi_cmd} ${training_bert_dir}/onnxruntime_training_bert --model_name ${model_dir} --train_data_dir ${data_dir} --test_data_dir ${data_dir} --train_batch_size ${batch_size} --mode train --num_train_steps ${num_train_steps} --optimizer ${optimizer} --learning_rate ${lr} --warmup_ratio ${warmup_ratio} --warmup_mode ${warmup_mode} --gradient_accumulation_steps ${grad_acc} --max_seq_length ${seq_len} --max_predictions_per_seq=${max_predictions_per_seq} --use_nccl --lambda 0 ${fp16_commands} --display_loss_steps ${display_loss_steps} --log_dir ${log_dir}/${run_name} > ${log_dir}/${run_name}.log 2>&1 &
-
-tail -f ${log_dir}/${run_name}.log
-
-exit 0
diff --git a/setup.py b/setup.py
index f6a697b1bb2b9..dd495da56c4c3 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,6 @@ def parse_arg_remove_string(argv, arg_name_equal):
 cuda_version = None
 cuda_major_version = None
-rocm_version = None
 is_migraphx = False
 is_openvino = False
 is_qnn = False
@@ -244,7 +243,7 @@ def run(self):
             "libnvrtc-builtins.so.13",
         ]
-        rocm_dependencies = [
+        migraphx_dependencies = [
            "libamd_comgr.so.2",
            "libamdhip64.so.5",
            "libamdhip64.so.6",
@@ -300,7 +299,7 @@ def run(self):
            file = glob(path.join(self.dist_dir, "*linux*.whl"))[0]
            logger.info("repairing %s for manylinux1", file)
            auditwheel_cmd = ["auditwheel", "-v", "repair", "-w", self.dist_dir, file]
-            for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies + cann_dependencies:
+            for i in cuda_dependencies + migraphx_dependencies + tensorrt_dependencies + cann_dependencies:
auditwheel_cmd += ["--exclude", i] logger.info("Running %s", " ".join([shlex.quote(arg) for arg in auditwheel_cmd])) try: @@ -322,7 +321,7 @@ def finalize_options(self): return ret -providers_cuda_or_rocm = "onnxruntime_providers_cuda" +providers_cuda = "onnxruntime_providers_cuda" providers_tensorrt_or_migraphx = "onnxruntime_providers_" + ("migraphx" if is_migraphx else "tensorrt") providers_nv_tensorrt_rtx = "onnxruntime_providers_nv_tensorrt_rtx" providers_openvino = "onnxruntime_providers_openvino" @@ -330,14 +329,14 @@ def finalize_options(self): providers_qnn = "onnxruntime_providers_qnn" if platform.system() == "Linux": - providers_cuda_or_rocm = "lib" + providers_cuda_or_rocm + ".so" + providers_cuda = "lib" + providers_cuda + ".so" providers_tensorrt_or_migraphx = "lib" + providers_tensorrt_or_migraphx + ".so" providers_nv_tensorrt_rtx = "lib" + providers_nv_tensorrt_rtx + ".so" providers_openvino = "lib" + providers_openvino + ".so" providers_cann = "lib" + providers_cann + ".so" providers_qnn = "lib" + providers_qnn + ".so" elif platform.system() == "Windows": - providers_cuda_or_rocm = providers_cuda_or_rocm + ".dll" + providers_cuda = providers_cuda + ".dll" providers_tensorrt_or_migraphx = providers_tensorrt_or_migraphx + ".dll" providers_nv_tensorrt_rtx = providers_nv_tensorrt_rtx + ".dll" providers_openvino = providers_openvino + ".dll" @@ -359,7 +358,7 @@ def finalize_options(self): "libonnxruntime.so*", ] dl_libs = ["libonnxruntime_providers_shared.so"] - dl_libs.append(providers_cuda_or_rocm) + dl_libs.append(providers_cuda) dl_libs.append(providers_tensorrt_or_migraphx) dl_libs.append(providers_cann) dl_libs.append(providers_qnn) @@ -369,7 +368,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_dnnl.so"]) libs.extend(["libonnxruntime_providers_openvino.so"]) libs.extend(["libonnxruntime_providers_vitisai.so"]) - libs.append(providers_cuda_or_rocm) + libs.append(providers_cuda) libs.append(providers_nv_tensorrt_rtx) libs.append(providers_tensorrt_or_migraphx) libs.append(providers_cann) @@ -410,7 +409,7 @@ def finalize_options(self): "dnnl.dll", "mklml.dll", "libiomp5md.dll", - providers_cuda_or_rocm, + providers_cuda, providers_tensorrt_or_migraphx, providers_nv_tensorrt_rtx, providers_cann, @@ -680,21 +679,14 @@ def finalize_options(self): if cuda_version: # removing '.' to make Cuda version number in the same form as Pytorch. local_version = "+cu" + cuda_version.replace(".", "") - elif rocm_version: - # removing '.' to make Rocm version number in the same form as Pytorch. 
- local_version = "+rocm" + rocm_version.replace(".", "") else: # cpu version for documentation local_version = "+cpu" else: - if not (cuda_version or rocm_version): + if not cuda_version: # Training CPU package for ADO feeds is called onnxruntime-training-cpu package_name = "onnxruntime-training-cpu" - if rocm_version: - # Training ROCM package for ADO feeds is called onnxruntime-training-rocm - package_name = "onnxruntime-training-rocm" - if package_name == "onnxruntime-tvm": packages += ["onnxruntime.providers.tvm"] @@ -796,7 +788,7 @@ def reformat_run_count(count_str): install_requires.append(f"nvidia-cuda-runtime-cu{major}~={major}.0") -def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version, qnn_version): +def save_build_and_package_info(package_name, version_number, cuda_version, qnn_version): sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python")) from onnxruntime_collect_build_info import find_cudart_versions # noqa: PLC0415 @@ -823,13 +815,11 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm else "found multiple cudart libraries" ), ) - elif rocm_version: - f.write(f"rocm_version = '{rocm_version}'\n") elif qnn_version: f.write(f"qnn_version = '{qnn_version}'\n") -save_build_and_package_info(package_name, version_number, cuda_version, rocm_version, qnn_version) +save_build_and_package_info(package_name, version_number, cuda_version, qnn_version) extras_require = {} if package_name == "onnxruntime-gpu" and cuda_major_version: diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py deleted file mode 100644 index 6a8154681ed97..0000000000000 --- a/tools/ci_build/amd_hipify.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import argparse -import os -import subprocess - - -def hipify(hipify_perl_path, src_file_path, dst_file_path): - dir_name = os.path.dirname(dst_file_path) - if not os.path.exists(dir_name): - os.makedirs(dir_name, exist_ok=True) - # Run hipify-perl first, capture output - s = subprocess.run([hipify_perl_path, src_file_path], stdout=subprocess.PIPE, text=True, check=False).stdout - - # Additional exact-match replacements. - # Order matters for all of the following replacements, reglardless of appearing in logical sections. 
- s = s.replace("kCudaExecutionProvider", "kRocmExecutionProvider") - s = s.replace("CUDAStreamType", "HIPStreamType") - s = s.replace("kCudaStreamDefault", "kHipStreamDefault") - s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn") - s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut") - s = s.replace("kTotalCudaStreams", "kTotalHipStreams") - - # in rocm 6.0, hipify-perl, the -roc option also maps __half -> rocblas_half which we don't want - s = s.replace("rocblas_half", "__half") - - # these should be "hip" but it's easier to just use rocm to avoid complicated file renaming - s = s.replace("CudaGraph", "RocmGraph") - s = s.replace("CUDAGraph", "ROCMGraph") - s = s.replace("cuda_graph", "rocm_graph") - s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels") - s = s.replace("cudaEvent", "hipEvent") - s = s.replace("CreateCudaAllocator", "CreateRocmAllocator") - s = s.replace("CudaErrString", "RocmErrString") - s = s.replace("CudaAsyncBuffer", "RocmAsyncBuffer") - s = s.replace("CudaKernel", "RocmKernel") - s = s.replace("CudaStream", "RocmStream") - s = s.replace("ToCudaType", "ToHipType") - s = s.replace("CudaT", "HipT") - s = s.replace("CUDA_LONG", "HIP_LONG") - s = s.replace("CUDA_RETURN_IF_ERROR", "HIP_RETURN_IF_ERROR") - s = s.replace("CUDA_KERNEL_ASSERT", "HIP_KERNEL_ASSERT") - s = s.replace("CUDA_CALL", "HIP_CALL") - s = s.replace("SliceCuda", "SliceRocm") - s = s.replace("thrust::cuda", "thrust::hip") - s = s.replace("CudaCall", "RocmCall") - s = s.replace("cuda", "rocm") - # s = s.replace('Cuda', 'Rocm') - s = s.replace("CUDA", "ROCM") - s = s.replace("GPU_WARP_SIZE = 32", "GPU_WARP_SIZE = 64") - s = s.replace("std::exp", "expf") - s = s.replace("std::log", "logf") - s = s.replace("WaitCudaNotificationOnDevice", "WaitRocmNotificationOnDevice") - s = s.replace("hipHostAlloc", "hipHostMalloc") - s = s.replace( - "#include ", - "#include \n#include ", - ) - s = s.replace( - '#include "cub/device/device_radix_sort.cuh"', - "#include \n#include ", - ) - s = s.replace( - "#include ", - "#include ", - ) - s = s.replace( - "#include ", "#include " - ) - s = s.replace( - "#include ", - "#include ", - ) - s = s.replace("#include ", "#include ") - s = s.replace( - "#include ", - "#include ", - ) - s = s.replace( - "#include ", - "#include ", - ) - s = s.replace("#include ", "#include ") - s = s.replace('#include "cub/util_allocator.cuh"', "#include ") - s = s.replace("#include ", "#include ") - s = s.replace('#include "cub/util_type.cuh"', "#include ") - s = s.replace("#include ", "#include ") - s = s.replace("#include ", "#include ") - s = s.replace("#include ", "") # Doesn't exist - s = s.replace("typedef half MappedType", "typedef __half MappedType") - - # CUBLAS -> HIPBLAS - s = s.replace("CUBLAS", "HIPBLAS") - s = s.replace("Cublas", "Hipblas") - s = s.replace("cublas", "hipblas") - # deprecated cublas symbol doesn't exist in hipblas, map to new symbol - s = s.replace("HIPBLAS_GEMM_DEFAULT_TENSOR_OP", "HIPBLAS_GEMM_DEFAULT") - - # Undefined ROCMRT constants -> std::numeric_limits - s = s.replace("ROCMRT_INF_F", "std::numeric_limits::infinity()") - - # compatible layer - s = s.replace("rocblas_gemm_strided_batched_ex", "_compat_rocblas_gemm_strided_batched_ex") - s = s.replace("RocblasMathModeSetter", "CompatRocblasMathModeSetter") - - # CURAND -> HIPRAND - s = s.replace("CURAND", "HIPRAND") - s = s.replace("Curand", "Hiprand") - s = s.replace("curand", "hiprand") - - # NCCL -> RCCL - # s = s.replace('NCCL_CALL', 'RCCL_CALL') - s = s.replace("#include ", 
"#include ") - - # CUDNN -> MIOpen - s = s.replace("CUDNN", "MIOPEN") - s = s.replace("Cudnn", "Miopen") - s = s.replace("cudnn", "miopen") - # hipify seems to have a bug for MIOpen, cudnn.h -> hipDNN.h, cudnn -> hipdnn - s = s.replace("#include ", "#include ") - s = s.replace("hipdnn", "miopen") - s = s.replace("HIPDNN_STATUS_SUCCESS", "miopenStatusSuccess") - s = s.replace("HIPDNN", "MIOPEN") - s = s.replace("MIOPEN_BATCHNORM_SPATIAL", "miopenBNSpatial") - s = s.replace("MIOPEN_BATCHNORM_PER_ACTIVATION", "miopenBNPerActivation") - s = s.replace("MIOPEN_LRN_CROSS_CHANNEL", "miopenLRNCrossChannel") - s = s.replace("MIOPEN_POOLING_MAX", "miopenPoolingMax") - s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING", "miopenPoolingAverageInclusive") - s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING", "miopenPoolingAverage") - - # CUSPARSE -> HIPSPARSE - s = s.replace("CUSPARSE", "HIPSPARSE") - - # CUFFT -> HIPFFT - s = s.replace("CUFFT", "HIPFFT") - s = s.replace("cufftXtMakePlanMany", "hipfftXtMakePlanMany") - s = s.replace("cufftXtExec", "hipfftXtExec") - - # Undo where above hipify steps went too far. - s = s.replace("id, ROCM", "id, CUDA") # cuda_execution_provider.cc - s = s.replace("ROCM error executing", "HIP error executing") - s = s.replace("ROCM_PINNED", "CUDA_PINNED") - s = s.replace("rocm_err", "hip_err") - s = s.replace("RegisterHipTrainingKernels", "RegisterRocmTrainingKernels") - s = s.replace("ROCM_VERSION", "CUDA_VERSION") # semantically different meanings, cannot hipify - s = s.replace("__ROCM_ARCH__", "__CUDA_ARCH__") # semantically different meanings, cannot hipify - # "std::log" above incorrectly changed "std::logic_error" to "logfic_error" - s = s.replace("logfic_error", "std::logic_error") - - # Deletions - s = s.replace('#include "device_atomic_functions.h"', "") # HIP atomics in main hip header already - - # Fix warnings due to incorrect header paths, intentionally after all other hipify steps. - s = s.replace("#include ", "#include ") - s = s.replace("#include ", "#include ") - s = s.replace("#include ", "#include ") - s = s.replace("#include ", "#include ") - s = s.replace('#include "hipfft.h"', "#include ") - s = s.replace('#include "hipfftXt.h"', "#include ") - - # Fix onnxruntime/contrib_ops/rocm/transformers. They include cpu headers which use "cuda" in their names. - s = s.replace("rocm_device_prop_", "cuda_device_prop_") - s = s.replace("rocm_device_arch_", "cuda_device_arch_") - - s = s.replace("HipTuningContext", "RocmTuningContext") - - # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names - # And we do this last, undoing or fixing hipify mistakes. 
- if "fft" in src_file_path: - s = s.replace("rocblas_datatype", "hipDataType") - s = s.replace("hipDataType_f32_c", "HIP_C_32F") - s = s.replace("hipDataType_f32_r", "HIP_R_32F") - s = s.replace("hipDataType_f64_c", "HIP_C_64F") - s = s.replace("hipDataType_f64_r", "HIP_R_64F") - s = s.replace("hipDataType_f16_c", "HIP_C_16F") - s = s.replace("hipDataType_f16_r", "HIP_R_16F") - - with open(dst_file_path, "w") as f: - f.write(s) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--hipify_perl", required=True) - parser.add_argument("--output", "-o", help="output file") - parser.add_argument("src", help="src") - args = parser.parse_args() - - hipify(args.hipify_perl, args.src, args.output) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 1fcd3fbe3daf0..77f07608528e2 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -272,8 +272,6 @@ def generate_vcpkg_install_options(build_dir, args): vcpkg_install_options.append("--x-feature=qnn-ep") if args.use_rknpu: vcpkg_install_options.append("--x-feature=rknpu-ep") - if args.use_rocm: - vcpkg_install_options.append("--x-feature=rocm-ep") if args.use_tensorrt: vcpkg_install_options.append("--x-feature=tensorrt-ep") if args.use_vitisai: @@ -347,7 +345,6 @@ def generate_build_tree( build_dir, cuda_home, cudnn_home, - rocm_home, nccl_home, tensorrt_home, tensorrt_rtx_home, @@ -375,7 +372,7 @@ def generate_build_tree( # enable/disable float 8 types disable_float8_types = args.android or ("float8" in types_to_disable) # enable/disable float 4 type - disable_float4_types = args.android or args.use_rocm or ("float4" in types_to_disable) + disable_float4_types = args.android or ("float4" in types_to_disable) disable_optional_type = "optional" in types_to_disable disable_sparse_tensors = "sparsetensor" in types_to_disable if is_windows(): @@ -1492,20 +1489,6 @@ def setup_dml_build(args, cmake_path, build_dir, configs): raise BuildError("use_dml and minimal_build may not both be set") -def setup_rocm_build(args): - rocm_home = None - if args.use_rocm: - print(f"rocm_home = {args.rocm_home}") - rocm_home = args.rocm_home or None - rocm_home_not_valid = rocm_home and not os.path.exists(rocm_home) - if rocm_home_not_valid: - raise BuildError( - "rocm_home paths must be specified and valid.", - f"rocm_home='{rocm_home}' valid={rocm_home_not_valid}.", - ) - return rocm_home or "" - - def run_android_tests(args, source_dir, build_dir, config, cwd): if args.android_abi != "x86_64": log.info(f"--android_abi ({args.android_abi}) is not x86_64, skipping running of Android tests on emulator.") @@ -1610,6 +1593,7 @@ def dump_logs_on_failure(): "libcustom_op_library.so", "libexample_plugin_ep_virt_gpu.so", "libexample_plugin_ep.so", + "libexample_plugin_ep_kernel_registry.so", "libonnxruntime_runtime_path_test_shared_library.so", "libonnxruntime.so", ] @@ -1760,12 +1744,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if is_windows(): cwd = os.path.join(cwd, config) - if ( - not args.skip_pip_install - and args.enable_transformers_tool_test - and not args.disable_contrib_ops - and not args.use_rocm - ): + if not args.skip_pip_install and args.enable_transformers_tool_test and not args.disable_contrib_ops: # PyTorch is required for transformers tests, and optional for some python tests. # Install cpu only version of torch when cuda is not enabled in Linux. 
extra = [] if args.use_cuda and is_linux() else ["--index-url", "https://download.pytorch.org/whl/cpu"] @@ -1946,9 +1925,7 @@ def build_python_wheel( use_cuda, cuda_home, cuda_version, - use_rocm, use_migraphx, - rocm_version, use_dnnl, use_tensorrt, use_openvino, @@ -1994,12 +1971,6 @@ def build_python_wheel( cuda_version = cuda_version or parse_cuda_version_from_json(cuda_home) if cuda_version: args.append(f"--cuda_version={cuda_version}") - elif use_rocm: - args.append("--use_rocm") - if rocm_version: - args.append(f"--rocm_version={rocm_version}") - if use_migraphx: - args.append("--use_migraphx") elif use_migraphx: args.append("--use_migraphx") elif use_openvino: @@ -2039,7 +2010,6 @@ def build_nuget_package( build_dir, configs, use_cuda, - use_rocm, use_openvino, use_tensorrt, use_dnnl, @@ -2094,8 +2064,6 @@ def build_nuget_package( package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu" elif use_dml: package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML" - elif use_rocm: - package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm" elif use_qnn: if use_qnn != "shared_lib": raise BuildError("Currently NuGet packages with QNN require QNN EP to be built as a shared library.") @@ -2447,9 +2415,6 @@ def main(): # if using migraphx, setup migraphx paths migraphx_home = setup_migraphx_vars(args) - # if using rocm, setup rocm paths - rocm_home = setup_rocm_build(args) - # if using cann, setup cann paths cann_home = setup_cann_vars(args) @@ -2572,16 +2537,12 @@ def main(): cwd=SCRIPT_DIR, ) - if args.use_rocm and args.rocm_version is None: - args.rocm_version = "" - generate_build_tree( cmake_path, source_dir, build_dir, cuda_home, cudnn_home, - rocm_home, nccl_home, tensorrt_home, tensorrt_rtx_home, @@ -2642,9 +2603,7 @@ def main(): args.use_cuda, cuda_home, args.cuda_version, - args.use_rocm, args.use_migraphx, - args.rocm_version, args.use_dnnl, args.use_tensorrt, args.use_openvino, @@ -2672,7 +2631,6 @@ def main(): build_dir, configs, args.use_cuda, - args.use_rocm, args.use_openvino, args.use_tensorrt, args.use_dnnl, diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 6763973406294..cd652a6cbb82e 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -763,9 +763,6 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: migx_group = parser.add_argument_group("MIGraphX Execution Provider") migx_group.add_argument("--use_migraphx", action="store_true", help="Enable MIGraphX EP.") migx_group.add_argument("--migraphx_home", help="Path to MIGraphX installation directory.") - migx_group.add_argument("--use_rocm", action="store_true", help="Enable ROCm EP.") - migx_group.add_argument("--rocm_version", help="ROCm stack version.") - migx_group.add_argument("--rocm_home", help="Path to ROCm installation directory.") # --- WebNN --- webnn_group = parser.add_argument_group("WebNN Execution Provider") diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index 526cc7bde519e..46cbac20627f7 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -71,7 +71,6 @@ def parse_arguments(): "vitisai", "winml", "cuda", - "rocm", "migraphx", "qnn", "snpe", diff --git a/tools/ci_build/github/linux/build_rocm_c_api_package.sh b/tools/ci_build/github/linux/build_rocm_c_api_package.sh deleted file mode 100755 index 3ea90c73342a5..0000000000000 --- a/tools/ci_build/github/linux/build_rocm_c_api_package.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash - -set -e -u -x - -usage() { echo "Usage: $0 
-S -B -V [-H ] " 1>&2; exit 1; } - -ROCM_HOME=/opt/rocm - -while getopts S:B:V:H:I:P: parameter_Option; do - case "${parameter_Option}" in - S) SOURCE_DIR=${OPTARG};; - B) BINARY_DIR=${OPTARG};; - V) ROCM_VERSION=${OPTARG};; - H) ROCM_HOME=${OPTARG};; - I) IMAGE=${OPTARG};; - P) PYTHON_BIN=${OPTARG};; - *) usage ;; - esac -done - -EXIT_CODE=1 - -docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --user $UID:$(id -g $USER) \ - -e NIGHTLY_BUILD \ - --volume $SOURCE_DIR:/onnxruntime_src \ - --volume $BINARY_DIR:/build \ - --volume /data/models:/build/models:ro \ - --volume /data/onnx:/data/onnx:ro \ - --workdir /onnxruntime_src \ - $IMAGE \ - /bin/bash -c "${PYTHON_BIN:-python} /onnxruntime_src/tools/ci_build/build.py --config Release --build_dir /build --parallel --use_rocm --use_binskim_compliant_compile_flags --rocm_version=$ROCM_VERSION --rocm_home $ROCM_HOME --nccl_home $ROCM_HOME --build_shared_lib --skip_submodule_sync --skip_tests --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER && cd /build/Release && make install DESTDIR=/build/installed" - - -EXIT_CODE=$? - -set -e -exit $EXIT_CODE diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile index 44b14d31919b2..e928801be858c 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile @@ -19,8 +19,8 @@ RUN dnf install -y --nodocs \ && dnf clean all \ && rm -rf /var/cache/dnf -ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2025.2.0 -ARG OPENVINO_PACKAGE_URL=https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.2/linux/openvino_toolkit_rhel8_2025.2.0.19140.c01cd93e24d_x86_64.tgz +ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2025.3.0 +ARG OPENVINO_PACKAGE_URL=https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.3/linux/openvino_toolkit_rhel8_2025.3.0.19807.44526285f24_x86_64.tgz ARG TEMP_DIR=/tmp/openvino_installer RUN mkdir -p ${TEMP_DIR} && \ diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh deleted file mode 100755 index 0be64d96f3a34..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -set -e -x - -# version -ROCM_VERSION=6.2.3 - -while getopts "r:" parameter_Option -do case "${parameter_Option}" -in -r) ROCM_VERSION=${OPTARG};; -esac -done - -tee /etc/yum.repos.d/amdgpu.repo <] [-d ] [-r ]" 1>&2; exit 1; } - -while getopts "n:d:r:" parameter_Option -do case "${parameter_Option}" -in -n) AGENT_NAME=${OPTARG};; -d) TARGET_DEVICE=${OPTARG};; -r) DRIVER_RENDER=${OPTARG};; -*) usage ;; -esac -done - -echo "Agent Name: $AGENT_NAME, Target Device: $TARGET_DEVICE, Driver Render: $DRIVER_RENDER" - -echo -e "\n ---- Execute rocm-smi" -rocm-smi - -echo -e "\n ---- Execute rocm-smi --showpids" -rocm-smi --showpids - -echo -e "\n ---- Execute rocm-smi --showpidgpus" -rocm-smi --showpidgpus - -echo -e "\n ---- Execute rocm-smi --showpids detail" -rocm-smi --showpids | awk '$1 ~/[0-9]+/{if((NR>6)) {print $1}}' | xargs -I {} ps {} - -echo -e "\n ---- Execute rocm-smi --showmeminfo" -rocm-smi --showmeminfo vram vis_vram gtt - -echo -e "\n ---- Clean up processes that use the target device $TARGET_DEVICE" -GPU_USED_BY_PIDS=$(rocm-smi --showpidgpus) 
-PID_NUMBERS_LINES=$(echo "$GPU_USED_BY_PIDS" | grep -n "DRM device" | cut -d ":" -f 1) -PID_NUMBERS_LINES_ARRAY=($PID_NUMBERS_LINES) - -for ((i = 0; i < ${#PID_NUMBERS_LINES_ARRAY[@]}; i++)); do - PID_NUMBER_LINE=${PID_NUMBERS_LINES_ARRAY[$i]} - PID_NUMBER=$(echo "$GPU_USED_BY_PIDS" | awk '{print $2}' | sed -n "${PID_NUMBER_LINE}p") - GPU_USED_BY_PID_LINE=$((PID_NUMBER_LINE + 1)) - GPU_USED_BY_PID=$(echo "$GPU_USED_BY_PIDS" | sed -n "${GPU_USED_BY_PID_LINE}p" | sed -e 's/^[ ]*//g' | sed -e 's/[ ]*$//g') - if [ "$GPU_USED_BY_PID" == "$TARGET_DEVICE" ]; then - echo "kill pid: $PID_NUMBER, using gpu: $GPU_USED_BY_PID" - kill -9 "$PID_NUMBER" - fi -done diff --git a/tools/ci_build/policheck_exclusions.xml b/tools/ci_build/policheck_exclusions.xml index a24eed809c5b8..9888245b48674 100644 --- a/tools/ci_build/policheck_exclusions.xml +++ b/tools/ci_build/policheck_exclusions.xml @@ -1,4 +1,4 @@ - LABELMAP.CS|OPERATORKERNELS.MD|BABEL.CONFIG.JS|METRO.CONFIG.JS|DMLOPERATORACTIVATION.CPP|DATA_OPS.CC|ONNX_CONVERTER.CC|ONNXOPS.PY|CPYTHON-PUBKEYS.TXT|AMD_HIPIFY.PY + LABELMAP.CS|OPERATORKERNELS.MD|BABEL.CONFIG.JS|METRO.CONFIG.JS|DMLOPERATORACTIVATION.CPP|DATA_OPS.CC|ONNX_CONVERTER.CC|ONNXOPS.PY|CPYTHON-PUBKEYS.TXT diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 899aaaa95216a..98e3e6a9b05b2 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -22,7 +22,6 @@ "linux-migraphx-ci-pipeline.yml", "linux-openvino-ci-pipeline.yml", "linux-qnn-ci-pipeline.yml", - "linux-rocm-ci-pipeline.yml", "mac-ci-pipeline.yml", "mac-coreml-ci-pipeline.yml", "mac-ios-ci-pipeline.yml", diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 6ce8c3b0bca91..9884cbf5793df 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -27,8 +27,6 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-cuda" elif ep == "tensorrt": pkg_name += "-tensorrt" - elif ep == "rocm": - pkg_name += "-rocm" elif ep == "migraphx": pkg_name += "-migraphx" elif os == "linux": @@ -38,8 +36,6 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-cuda" elif ep == "tensorrt": pkg_name += "-tensorrt" - elif ep == "rocm": - pkg_name += "-rocm" elif ep == "migraphx": pkg_name += "-migraphx" elif os == "osx": @@ -375,7 +371,6 @@ def generate_files(line_list, args): is_cuda_gpu_package = args.package_name == "Microsoft.ML.OnnxRuntime.Gpu" is_cuda_gpu_win_sub_package = args.package_name == "Microsoft.ML.OnnxRuntime.Gpu.Windows" is_cuda_gpu_linux_sub_package = args.package_name == "Microsoft.ML.OnnxRuntime.Gpu.Linux" - is_rocm_gpu_package = args.package_name == "Microsoft.ML.OnnxRuntime.ROCm" is_dml_package = args.package_name == "Microsoft.ML.OnnxRuntime.DirectML" is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" @@ -440,7 +435,6 @@ def generate_files(line_list, args): "tensorrt_ep_shared_lib": "libonnxruntime_providers_tensorrt.so", "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", - "rocm_ep_shared_lib": "libonnxruntime_providers_rocm.so", "migraphx_ep_shared_lib": "libonnxruntime_providers_migraphx.so", "onnxruntime_perf_test": "onnxruntime_perf_test", "onnx_test_runner": "onnx_test_runner", @@ -631,8 +625,6 @@ def generate_files(line_list, 
args): # downloaded from other build jobs if is_cuda_gpu_package or is_cuda_gpu_win_sub_package or is_cuda_gpu_linux_sub_package: ep_list = ["tensorrt", "cuda", None] - elif is_rocm_gpu_package: - ep_list = ["rocm", None] elif is_migraphx_package: ep_list = ["migraphx", None] else: @@ -742,24 +734,6 @@ def generate_files(line_list, args): + '\\native" />' ) - if args.execution_provider == "rocm" or (is_rocm_gpu_package and not is_ado_packaging_build): - files_list.append( - "' - ) - files_list.append( - "' - ) - if args.execution_provider == "openvino": openvino_path = get_env_var("INTEL_OPENVINO_DIR") files_list.append( @@ -998,7 +972,6 @@ def _files_list_append(key: str): or is_cuda_gpu_package or is_cuda_gpu_linux_sub_package or is_cuda_gpu_win_sub_package - or is_rocm_gpu_package or is_migraphx_package or is_dml_package or is_mklml_package @@ -1240,12 +1213,11 @@ def validate_execution_provider(execution_provider): or execution_provider == "cuda" or execution_provider == "tensorrt" or execution_provider == "openvino" - or execution_provider == "rocm" or execution_provider == "migraphx" ): raise Exception( "On Linux platform nuget generation is supported only " - "for cpu|cuda|dnnl|tensorrt|openvino|rocm execution providers." + "for cpu|cuda|dnnl|tensorrt|openvino|migraphx execution providers." )