diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 6c3f2eb0fbbe1..725c40c2ded53 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -24,7 +24,7 @@ jobs: name: Download prebuilt ONNX Runtime archive from build.rs runs-on: ubuntu-latest env: - ORT_RUST_STRATEGY=download + ORT_RUST_STRATEGY: download steps: - uses: actions/checkout@v4 - uses: ./.github/actions/rust-toolchain-setup diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 95607f297c6bd..c94e3fa5bcb8c 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v8.0.0 + - uses: actions/stale@v9.0.0 with: # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: contributions welcome, feature request, regression @@ -29,7 +29,7 @@ jobs: # Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale stale-issue-label: "stale" # Comment that you want to add to issues that are labeled by the actions/stale action - stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details." + stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details." # Comment that you want to add to issues that are closed by the actions/stale action close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed." 
# If you never want this action to label PRs, set this value to -1 diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index ba24e7eebfb03..3a780f87d2300 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -49,13 +49,10 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - uses: actions/setup-python@v4 - with: - python-version: '3.8.x' - architecture: 'x64' - uses: conda-incubator/setup-miniconda@v2 with: - activate-environment: "" + activate-environment: "ort_build" + python-version: 3.8 - name: 'Install LLVM-Dev' shell: pwsh run: | diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml index 45ebf889c5da1..292ce60c6b6cf 100644 --- a/.pipelines/windowsai-steps.yml +++ b/.pipelines/windowsai-steps.yml @@ -84,7 +84,7 @@ jobs: 7z x cmake-3.26.3-windows-x86_64.zip set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools - $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe + $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/.vscode/settings.json b/.vscode/settings.json index c4a08e3232a82..2f2adc78f6de9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -13,6 +13,7 @@ "editor.codeActionsOnSave": { "source.organizeImports": true }, + "editor.defaultFormatter": "ms-python.black-formatter" }, // Enable Python linting and Pylance type checking "python.analysis.typeCheckingMode": "basic", diff --git a/README.md b/README.md index 22ef387f5a7cd..33bce867e3bde 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ |Android|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)|| |iOS|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)|| |Web|[![Build 
Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/ONNX%20Runtime%20Web%20CI%20Pipeline?label=Web)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)|| -|Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-python-checks-ci-pipeline?label=Python+Checks)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=164)|| +|Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)|| ## Third-party Pipeline Status diff --git a/build_arm64x.bat b/build_arm64x.bat new file mode 100644 index 0000000000000..fbcdd373086a9 --- /dev/null +++ b/build_arm64x.bat @@ -0,0 +1,12 @@ +:: Copyright (c) Microsoft Corporation. All rights reserved. +:: Licensed under the MIT License. + +@echo off + +setlocal +set PATH=C:\Program Files\Git\usr\bin;%PATH% +set LINK_REPRO_NAME=/mylink.rsp + +rem Requires a Python install to be available in your PATH +python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %* +python "%~dp0\tools\ci_build\build.py" --arm64ec --buildasx --build_dir "%~dp0\build\arm64ec-x" %* diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 12fbb291c3a70..137ea8a50c011 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "29bf8085f3bf17b84d30e34b3d7ff8248fda404e", + "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" @@ -126,7 +126,7 @@ "component": { "type": "git", "git": { - "commitHash": "f8d7d77c06936315286eb55f8de22cd23c188571", + "commitHash": "530d5c8c84abd2a46f38583ee817743c9b3a42b4", "repositoryUrl": "https://github.com/google/googletest.git" }, "comments": "googletest" @@ -316,7 +316,7 @@ "component": { "type": "git", "git": { - "commitHash": "a4f72a314a85732ed67d5aa8d1088d207a7e0e61", + "commitHash": "5356c4a943a35e74d7cdc69486afcb8703b9a59a", "repositoryUrl": "https://github.com/ROCmSoftwarePlatform/composable_kernel.git" }, "comments": "composable_kernel" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a9dc15b319c37..4a98849c05ef1 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) +option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() +set(USE_JBLAS FALSE) +if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD) + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + add_compile_definitions(MLAS_JBLAS) + set(USE_JBLAS TRUE) + elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + add_compile_definitions(MLAS_JBLAS) + set(USE_JBLAS TRUE) + endif() +endif() + # TVM EP if (onnxruntime_USE_TVM) if (NOT TARGET tvm) @@ -1269,7 +1281,7 @@ if (onnxruntime_USE_OPENVINO) 
add_definitions(-DOPENVINO_2023_1=1) elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.2") set(OPENVINO_VERSION "2023.2") - add_definitions(-DOPENVINO_2023_1=1) + add_definitions(-DOPENVINO_2023_2=1) elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino") set(OPENVINO_VERSION "2023.2") add_definitions(-DOPENVINO_2023_2=1) @@ -1293,6 +1305,14 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() + if (onnxruntime_USE_OPENVINO_NPU_FP16) + add_definitions(-DOPENVINO_CONFIG_NPU_FP16=1) + endif() + + if (onnxruntime_USE_OPENVINO_NPU_U8) + add_definitions(-DOPENVINO_CONFIG_NPU_U8=1) + endif() + if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1313,6 +1333,16 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() + if (onnxruntime_USE_OPENVINO_NPU_FP16_NP) + add_definitions(-DOPENVINO_CONFIG_NPU_FP16=1) + add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) + endif() + + if (onnxruntime_USE_OPENVINO_NPU_U8_NP) + add_definitions(-DOPENVINO_CONFIG_NPU_U8=1) + add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) + endif() + if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") @@ -1584,6 +1614,13 @@ set(VERSION_STRING "Internal Build" CACHE STRING "String representation of if (WIN32) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SYS_PATH_LIB}) list(APPEND onnxruntime_EXTERNAL_LIBRARIES debug Dbghelp) + # In a onecore build the umbrella libs already contains references to the APIs in advapi32, so in onecore build we do not need to link to advapi32 + # In a non-onecore build, usually we also do not need to link to advapi32 because VC++ by default should have provide everything we need, except when the build target is Windows ARM32. + # In the future we will add a build option to allow users disabling all API uses from advapi32 because some Windows environments do not have these APIs. For example, some Windows do not have + # Windows Registry so we cannot query Registry values. 
+ if(onnxruntime_target_platform STREQUAL "ARM" AND CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES advapi32) + endif() else() list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync::nsync_cpp) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${ICONV_LIB} ${CMAKE_DL_LIBS} Threads::Threads) @@ -1773,3 +1810,8 @@ if(TARGET onnxruntime) "${PROJECT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake" DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}") endif() + +if(DEFINED BUILD_AS_ARM64X) + set(ARM64X_TARGETS onnxruntime) + include("${CMAKE_SOURCE_DIR}/arm64x.cmake") +endif() diff --git a/cmake/arm64x.cmake b/cmake/arm64x.cmake new file mode 100644 index 0000000000000..be476e09625bd --- /dev/null +++ b/cmake/arm64x.cmake @@ -0,0 +1,33 @@ +set(arm64ReproDir "${CMAKE_SOURCE_DIR}/repros") + +if("${BUILD_AS_ARM64X}" STREQUAL "ARM64") + foreach (n ${ARM64X_TARGETS}) + add_custom_target(mkdirs_${n} ALL COMMAND cmd /c (if exist \"${arm64ReproDir}/${n}_temp/\" rmdir /s /q \"${arm64ReproDir}/${n}_temp\") && mkdir \"${arm64ReproDir}/${n}_temp\" ) + add_dependencies(${n} mkdirs_${n}) + target_link_options(${n} PRIVATE "/LINKREPRO:${arm64ReproDir}/${n}_temp") + add_custom_target(${n}_checkRepro ALL COMMAND cmd /c if exist \"${n}_temp/*.obj\" if exist \"${n}\" rmdir /s /q \"${n}\" 2>nul && if not exist \"${n}\" ren \"${n}_temp\" \"${n}\" DEPENDS ${n} + WORKING_DIRECTORY ${arm64ReproDir}) + endforeach() + + +elseif("${BUILD_AS_ARM64X}" STREQUAL "ARM64EC") + foreach (n ${ARM64X_TARGETS}) + set(ARM64_LIBS) + set(ARM64_OBJS) + set(ARM64_DEF) + + file(GLOB ARM64_OBJS "${arm64ReproDir}/${n}/*.obj") + file(GLOB ARM64_DEF "${arm64ReproDir}/${n}/*.def") + file(GLOB ARM64_LIBS "${arm64ReproDir}/${n}/*.LIB") + + if(NOT "${ARM64_DEF}" STREQUAL "") + set(ARM64_DEF "/defArm64Native:${ARM64_DEF}") + endif() + target_sources(${n} PRIVATE ${ARM64_OBJS}) + target_link_options(${n} PRIVATE /machine:arm64x "${ARM64_DEF}") + + if(NOT "${ARM64_LIBS}" STREQUAL "") + target_link_libraries(${n} PUBLIC ${ARM64_LIBS}) + endif() + endforeach() +endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index 49142372ab86e..ff07803013071 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. 
# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip;04271dfbfac59269b6939e1e9d5faf0d18a7ba91 +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -27,7 +27,7 @@ fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 -googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc +googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 @@ -54,4 +54,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/a4f72a314a85732ed67d5aa8d1088d207a7e0e61.zip;f57357ab6d300e207a632d034ebc8aa036a090d9 +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 708d6ba18750b..1e5a36fb9efb9 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -30,7 +30,6 @@ - empty size={ _size() } size=({_size()}) diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake index 397c4d6abeb9a..d7b70640781d0 100644 --- a/cmake/external/dnnl.cmake +++ b/cmake/external/dnnl.cmake @@ -25,6 +25,16 @@ elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl" AND set(DNNL_GPU_CMAKE_ARGS "-DDNNL_GPU_RUNTIME=OCL " "-DOPENCLROOT=${onnxruntime_DNNL_OPENCL_ROOT}") 
endif() +if(onnxruntime_USE_DNNL AND onnxruntime_DNNL_AARCH64_RUNTIME STREQUAL "acl" AND onnxruntime_DNNL_ACL_ROOT STREQUAL "") + message(FATAL_ERROR "--dnnl_acl_root required") +elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_AARCH64_RUNTIME STREQUAL "" AND NOT (onnxruntime_DNNL_ACL_ROOT STREQUAL "")) + message(FATAL_ERROR "--dnnl_aarch64_runtime required") +elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_AARCH64_RUNTIME STREQUAL "acl" AND NOT (onnxruntime_DNNL_ACL_ROOT STREQUAL "")) + file(TO_CMAKE_PATH ${onnxruntime_DNNL_ACL_ROOT} onnxruntime_DNNL_ACL_ROOT) + set(ACL_INCLUDE_DIR ${onnxruntime_DNNL_ACL_ROOT}/arm_compute) + set(DNNL_AARCH64_CMAKE_ARGS "-DDNNL_AARCH64_USE_ACL=ON") +endif() + if (onnxruntime_USE_DNNL) set(DNNL_SOURCE ${CMAKE_CURRENT_BINARY_DIR}/dnnl/src/dnnl/src) set(DNNL_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/dnnl/install) @@ -51,7 +61,7 @@ if (onnxruntime_USE_DNNL) GIT_TAG ${DNNL_TAG} # PATCH_COMMAND ${MKLDNN_PATCH_DISCARD_COMMAND} COMMAND ${DNNL_PATCH_COMMAND} SOURCE_DIR ${DNNL_SOURCE} - CMAKE_ARGS -DDNNL_BUILD_TESTS=OFF -DDNNL_ENABLE_CONCURRENT_EXEC=ON -DDNNL_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${DNNL_INSTALL} ${DNNL_GPU_CMAKE_ARGS} + CMAKE_ARGS -DDNNL_BUILD_TESTS=OFF -DDNNL_ENABLE_CONCURRENT_EXEC=ON -DDNNL_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${DNNL_INSTALL} ${DNNL_GPU_CMAKE_ARGS} ${DNNL_AARCH64_CMAKE_ARGS} ) link_directories(${DNNL_LIB_DIR}) endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 0fa5163dc06bf..78f63227c8392 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -47,8 +47,8 @@ if (onnxruntime_BUILD_UNIT_TESTS) FetchContent_Declare( googletest URL ${DEP_URL_googletest} - FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest URL_HASH SHA1=${DEP_SHA1_googletest} + FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest ) endif() @@ -124,7 +124,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) if(protoc_binary_SOURCE_DIR) message("Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -140,7 +140,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) if(protoc_binary_SOURCE_DIR) message("Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) @@ -281,7 +281,7 @@ if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID) pytorch_clog URL ${DEP_URL_pytorch_cpuinfo} URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - SOURCE_SUBDIR deps/clog + SOURCE_SUBDIR deps/clog ) set(ONNXRUNTIME_CLOG_PROJ pytorch_clog) set(ONNXRUNTIME_CLOG_TARGET_NAME clog) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 9d9b006c595bb..c900f4d4b09a5 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ 
-282,11 +282,7 @@ endif() # Assemble the Apple static framework (iOS and macOS) if(onnxruntime_BUILD_APPLE_FRAMEWORK) - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) - else() # macOS - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) - endif() + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) # Setup the various directories required. Remove any existing ones so we start with a clean directory. set(STATIC_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/static_libraries) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 04efa5c2b4f6d..bee83ff07c74b 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -45,6 +45,15 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) +function(add_jblas) + add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) + target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/jblas_gemm.cpp + ) + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) +endfunction() + #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -200,7 +209,6 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/q4gemm_avx512.cpp ) endif() - else() target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp @@ -284,6 +292,8 @@ else() set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") + set(LOONGARCH64 TRUE) endif() endif() @@ -564,7 +574,7 @@ else() ) set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() + endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) @@ -575,6 +585,26 @@ else() set(MLAS_SOURCE_IS_NOT_SET 0) endif() endif() + if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S + ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S + ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx") + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/scalar/*.cpp") @@ -582,6 +612,10 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() +if(USE_JBLAS) + add_jblas() +endif() + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) diff --git a/cmake/onnxruntime_optimizer.cmake 
b/cmake/onnxruntime_optimizer.cmake index baea52e84ace2..6f09583199ffd 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -86,6 +86,8 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/core/optimizer/*.cc" "${ORTTRAINING_SOURCE_DIR}/core/optimizer/compute_optimizer/*.h" "${ORTTRAINING_SOURCE_DIR}/core/optimizer/compute_optimizer/*.cc" + "${ORTTRAINING_SOURCE_DIR}/core/optimizer/memory_optimizer/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/optimizer/memory_optimizer/*.cc" ) endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index f2a16fb29dc62..84d1376f99d5e 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -34,6 +34,8 @@ if (NOT onnxruntime_USE_NCCL) list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" @@ -172,10 +174,8 @@ target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) - include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) - endif() + include(cutlass) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 7ac4a82c89a76..0951c2d02664d 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -15,16 +15,10 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" ) - list(REMOVE_ITEM onnxruntime_providers_vitisai_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc") source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - onnxruntime_add_shared_library(onnxruntime_vitisai_ep ${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc) - onnxruntime_add_include_to_target(onnxruntime_vitisai_ep onnxruntime_common) - target_include_directories(onnxruntime_vitisai_ep PRIVATE "${ONNXRUNTIME_ROOT}" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include") - target_link_libraries(onnxruntime_providers_vitisai PUBLIC onnxruntime_vitisai_ep PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json ) - target_compile_definitions(onnxruntime_vitisai_ep - PRIVATE "-DONNXRUNTIME_VITISAI_EP_STUB=1" "-DONNXRUNTIME_VITISAI_EP_EXPORT_DLL=1") + target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) if(NOT MSVC) 
target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) endif(NOT MSVC) @@ -49,4 +43,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index a9a78668b4810..61922961588b2 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py) if (onnxruntime_ENABLE_TRAINING) - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py" - ) file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/*.py" ) @@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*" ) endif() -else() - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/training/*.py" - ) endif() if (onnxruntime_BUILD_UNIT_TESTS) @@ -443,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx" ) + file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx" + ) endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS @@ -457,6 +453,12 @@ file(GLOB onnxruntime_python_quantization_operators_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/CalTableFlatBuffers/*.py" ) +file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/quantization/fusions/*.py" +) +file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py" +) file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py" ) @@ -551,11 +553,15 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/CalTableFlatBuffers + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/fusions + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers/qnn COMMAND ${CMAKE_COMMAND} -E make_directory $/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/eager_test + COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/conformer COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ @@ -577,9 +583,6 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py $/onnxruntime/capi/ - COMMAND 
${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy $ $/onnxruntime/capi/ @@ -623,6 +626,12 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_cal_table_flatbuffers_src} $/onnxruntime/quantization/CalTableFlatBuffers/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_quantization_fusions_src} + $/onnxruntime/quantization/fusions/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_quantization_ep_qnn_src} + $/onnxruntime/quantization/execution_providers/qnn/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_src} $/onnxruntime/transformers/ @@ -711,6 +720,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_testdata_whisper} $/transformers/test_data/models/whisper/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_testdata_conformer} + $/transformers/test_data/models/conformer/ ) endif() @@ -750,9 +762,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/data/ COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/hooks/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_root_srcs} $/onnxruntime/training/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 980bd59b22c3f..f70961a66329a 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -109,6 +109,8 @@ if (NOT onnxruntime_USE_NCCL) # Those are string patterns to exclude. Do NOT use stars such as # collective/*.cc or *.h. 
list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc") + list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h") + list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc") list(APPEND contrib_ops_excluded_files "collective/sharding.cc") list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index a52e941b235b4..7c8c70f913dca 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -783,7 +783,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) - target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock) + target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() @@ -1373,56 +1373,55 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32) endif() - file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS - "${TEST_SRC_DIR}/mlas/unittest/*.h" - "${TEST_SRC_DIR}/mlas/unittest/*.cpp" - ) - onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) - if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" - "$<$>:/wd26409>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" - "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" - "$<$>:/wd26426>") - endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnxruntime_mlas_test PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + if(NOT onnxruntime_target_platform STREQUAL "ARM64EC") + file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS + "${TEST_SRC_DIR}/mlas/unittest/*.h" + "${TEST_SRC_DIR}/mlas/unittest/*.cpp" ) - endif() - target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} - ${CMAKE_CURRENT_BINARY_DIR}) - target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) - endif() - if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_test PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads) - - set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (CMAKE_SYSTEM_NAME STREQUAL 
"Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) + if(MSVC) + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" + "$<$>:/wd26409>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" + "$<$>:/wd6326>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" + "$<$>:/wd26426>") endif() - endif() - + if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + set_target_properties(onnxruntime_mlas_test PROPERTIES + XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + ) + endif() + target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} + ${CMAKE_CURRENT_BINARY_DIR}) + target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) + endif() + if(NOT WIN32) + target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) + endif() + if(WIN32) + target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32) + endif() + if (onnxruntime_LINK_LIBATOMIC) + target_link_libraries(onnxruntime_mlas_test PRIVATE atomic) + endif() + target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads) + set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest") + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + endif() + endif() +endif() # Training API Tests # Disabling training_api_test_trainer. CXXOPT generates a ton of warnings because of which nuget pipeline is failing. # TODO(askhade): Fix the warnings. diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index 02b30af9eef52..15844dd917744 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index b09da41a8..fca2bdf69 100644 +index 04674124c..12e8b8b00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ endif() @@ -48,7 +48,18 @@ index b09da41a8..fca2bdf69 100644 ## tidy include(EnableCompilerWarnings) -@@ -489,11 +466,3 @@ rocm_install(FILES +@@ -376,7 +353,9 @@ if(BUILD_DEV) + add_compile_options(-Werror -Weverything) + endif() + #add flags to reduce the size of binaries +-add_compile_options(-Oz -flto=thin) ++# -flto requires ORT to use a linker that support LTO and -flto flag shoud be passed to linker together. 
++# add_compile_options(-Oz -flto=thin) ++add_compile_options(-Oz) + message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +@@ -482,11 +461,3 @@ rocm_install(FILES set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") @@ -61,7 +72,7 @@ index b09da41a8..fca2bdf69 100644 - HEADER_ONLY -) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -index a0478c9f0..1e7782cd4 100644 +index 9cb5d0e9a..141a46f3d 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -44,8 +44,14 @@ function(add_instance_library INSTANCE_NAME) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 395996f0fa4b9..268ee3960e75a 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -451,6 +451,8 @@ onnxruntime_add_static_library(winml_lib_api ${winml_lib_api_dir}/impl/TensorKindFrom.h ${winml_lib_api_dir}/impl/TensorMemoryBufferReference.h ${winml_lib_api_dir}/NumericData.cpp + ${winml_lib_api_dir}/HardwareCoreEnumerator.cpp + ${winml_lib_api_dir}/HardwareCoreEnumerator.h ${winml_lib_api_dir}/ImageFeatureDescriptor.cpp ${winml_lib_api_dir}/ImageFeatureDescriptor.h ${winml_lib_api_dir}/ImageFeatureValue.cpp diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 0c74a23204d4f..1d15383239baf 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -6,7 +6,7 @@ true - netstandard2.0 + netstandard2.0;netcoreapp3.1;net6.0 diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 86b44a6784817..163a2b394c4ae 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -263,12 +263,16 @@ public ReadOnlyMemory GetStringElementAsMemory(int index) /// UTF-16 string instance public string GetStringElement(int index) { - var chars = GetStringTensorElementChars(index); - if (chars.Length == 0) + GetStringTensorElementBuffer((UIntPtr)index, out uint bytesLen, out IntPtr bufferPtr); + if (bytesLen == 0) { return string.Empty; } - return new string(chars); + + unsafe + { + return Encoding.UTF8.GetString((byte*)bufferPtr.ToPointer(), (int)bytesLen); + } } diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9c31978c66486..131db5d8d9b37 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1599,14 +1599,14 @@ This version of the operator has been available since version 1 of the 'com.micr #### Inputs (1 - ∞)
<dl>
-<dt><tt>inputs</tt> (variadic) : T</dt>
+<dt><tt>inputs</tt> (variadic, heterogeneous) : T</dt>
<dd>List of tensors for inputs</dd>
</dl>
#### Outputs (1 - ∞)

<dl>
-<dt><tt>outputs</tt> (variadic) : T</dt>
+<dt><tt>outputs</tt> (variadic, heterogeneous) : T</dt>
<dd>One or more outputs, list of tensors for outputs</dd>
</dl>
@@ -2385,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.micr Group Query Self/Cross Attention. - Supports different number of heads for q and kv. + Supports different number of heads for q and kv. Only supports causal or local attention. #### Version @@ -2396,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
+<dt><tt>local_window_size</tt> : int</dt>
+<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>scale</tt> : float</dt>
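To make the `local_window_size` semantics concrete, here is a small NumPy sketch of a causal mask restricted to a left window. It is an illustration only, not the kernel's code; whether the boundary includes exactly `local_window_size` past tokens is an assumption to verify against the implementation.

```python
import numpy as np

def causal_local_mask(seq_len: int, local_window_size: int) -> np.ndarray:
    """True where a query position may attend to a key position."""
    i = np.arange(seq_len)[:, None]  # query positions
    j = np.arange(seq_len)[None, :]  # key positions
    causal = j <= i                  # no attention to future tokens
    if local_window_size < 0:        # -1: window disabled, plain causal attention
        return causal
    return causal & (j > i - local_window_size)  # keep only the nearest left tokens

print(causal_local_mask(5, 2).astype(int))
```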
@@ -2647,8 +2649,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
<dl>
-<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
-<dd>Constrain input and output types to float/half_float tensors.</dd>
+<dt><tt>T1</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
+<dd>Constrain input and output types to float/half_float/brain_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>
@@ -2822,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
+<dt><tt>accuracy_level</tt> : int</dt>
+<dd>The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.</dd>
<dt><tt>bits</tt> : int (required)</dt>
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
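For context, the new attribute rides along with the existing required ones. Below is a hedged sketch of a MatMulNBits node carrying it, using `onnx.helper`; the tensor names, shapes, and exact input list are illustrative assumptions, so check the op registration for the authoritative schema.

```python
from onnx import helper

# Illustrative MatMulNBits node: 4-bit weights, with accuracy_level=2 allowing
# input A to be downcast to fp16 internally during computation.
node = helper.make_node(
    "MatMulNBits",
    inputs=["A", "B_packed", "scales"],  # zero_points would be a 4th, optional input
    outputs=["Y"],
    domain="com.microsoft",
    K=4096,            # size of each input feature
    N=11008,           # size of each output feature
    bits=4,            # 4-bit quantized weights
    block_size=32,     # quantization block size
    accuracy_level=2,  # 2 = fp16, per the attribute description above
)
```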
@@ -5021,7 +5025,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>input</tt> : T</dt>
-<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
@@ -5034,7 +5038,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>output</tt> : T</dt>
-<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>tensor with same shape as input.</dd>
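As a reading aid for the shapes above, here is a minimal NumPy sketch of the non-interleaved rotation this op applies. It is an illustration under simplifying assumptions (3D input, per-position cos/sin already gathered), not the kernel itself, which also supports interleaved mode and the 4D layout.

```python
import numpy as np

def rotary_embedding(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    """x: (batch, seq, hidden); cos/sin: (seq, hidden // 2). Output has x's shape."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by the cached angles.
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

x = np.random.randn(2, 8, 64).astype(np.float32)
angles = np.random.randn(8, 32).astype(np.float32)
out = rotary_embedding(x, np.cos(angles), np.sin(angles))
assert out.shape == x.shape  # output: tensor with same shape as input
```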
#### Type Constraints

diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md
index e9ceae00a684d..97f7e7ff2c14b 100644
--- a/docs/Memory_Optimizer.md
+++ b/docs/Memory_Optimizer.md
@@ -17,74 +17,149 @@ Classical scenarios include:
 Not all models and recipes need this optimizer technique. Imagine if your training recipe uses a batch size of 6 (GPU compute and memory are fully saturated), and you don't need to bump it to 8 to maintain a fixed global batch size. Enabling recompute may not bring better throughput at batch size 8 than at the original batch size 6.
-## Quick trial
+## Usage
-1. Make sure ONNX Runtime training wheel is installed and correctly configured.
-2. Integrate models using `ORTModule`, be noted log_level should be equal to or lower than DEVINFO.
-   > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
-3. Run the training as usual and redirect all outputs into the log file; then stop it after training a few steps.
-4. Check the logging file, and search "Summary", you could find something like this:
+
+Make sure the ONNX Runtime training wheel is installed and correctly configured.
+Integrate models using `ORTModule`.
+```diff
+  model = build_model()
+
++ from onnxruntime.training.ortmodule import ORTModule
++ model = ORTModule(model)
+```
+
+There are two modes to enable the memory optimizations (both summarized in the sketch below):
+- Aggressively Recompute All Within Each Transformer Layer, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention+MLP layer. It is easy to enable, but note that this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
+- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage that allows users to find the most suitable subgraphs to recompute, at the cost of some overhead to look for the best plans.
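For quick reference, the two modes reduce to the following environment settings; the subgraph names in the second export are illustrative placeholders to be taken from the logs of your own run:

```bash
# Mode 1: aggressively recompute all detected subgraphs within each transformer layer.
# Any ORTMODULE_MEMORY_OPT_CONFIG value is ignored in this mode.
export ORTMODULE_MEMORY_OPT_LEVEL=1

# Mode 2: recompute only user-selected subgraphs (names here are illustrative).
export ORTMODULE_MEMORY_OPT_LEVEL=0
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:-1,Cast+:2:-1"
```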
+### Mode 1 - Simple Usage (Aggressively Recompute All Within Each Transformer Layer)
+
+1. Set the memory optimization level to TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`.
+2. Run the training as usual; check the logs, and you could find something like this if the current log level <= LogLevel.INFO:
+ ```
+ Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1]
+ Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
+ - Plan 1 : ON : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+ - Plan 2 : ON : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 3 : ON : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 5 : ON : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 6 : ON : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 7 : ON : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
 ```
- MemoryOptimizer Summary:
- User config:
-
- =================================
- ########Recompute########
- Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
- OptimizationType: Disabled
- Patterns:
- PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
- --------------------------------
- Subgraph: FastGelu+
- OptimizationType: Disabled
- Patterns:
- PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
- =================================
- ########RecomputeWithCompromise########
- Subgraph: Cast+Where+Softmax+
- OptimizationType: Disabled
- Patterns:
- PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
- --------------------------------
- =================================
+3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case.
+
+### Mode 2 - Advanced Usage (User Selected Subgraph Recompute)
+
+1. Note that `ORTMODULE_MEMORY_OPT_LEVEL` is 0 by default. Run the training as usual; then stop it after training a few steps.
+2. Check the logs; you could find something like this if the current log level <= LogLevel.INFO:
+ ```
+ Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
+ Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
+ - Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+ - Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 4 : OFF : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
 ```
-5. As shown above, 'Subgraph' shows 1) a string representative for a re-computable subgraph; and 2) current status of memory optimization. All are disabled for recompute in this case.
-6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, 12 FastGelu related subgraphs are allowed to recompute.
-`FastGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `12` means the initial 12 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.
- ```
- export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12"
+3. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
+4. Set the environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraphs to do recompute.
+ ```bash
+ # Use a comma as a separator for enabling more than one subgraph.
+ export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1"
+ # Explanation:
+ # > BiasGelu+ is the subgraph string representative;
+ # > 1 in the middle indicates 'Recompute' is enabled (0, on the contrary, indicates it's disabled);
+ # > the last 1 means only the first subgraph occurrence will be recomputed, and all others are left as is; filling in `-1` will make all occurrences be recomputed.
+ ```
-7. Then run the training again, you will see logs like this:
+5. Then run the training again, and you will see logs like this:
 ```
- MemoryOptimizer Summary:
- User config:
- **FastGelu+:1:12**
- =================================
- ########Recompute########
- Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
- OptimizationType: Disabled
- Patterns:
- PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
- --------------------------------
- Subgraph: FastGelu+
- OptimizationType: **Recompute (requested_count=12, actual applied_count=12)**
- Patterns:
- PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
- =================================
- ########RecomputeWithCompromise########
- Subgraph: Cast+Where+Softmax+
- OptimizationType: Disabled
- Patterns:
- PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
- --------------------------------
- =================================
+ Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1]
+ Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
+ - Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+ - Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ ```
-8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well.
+6. You may need to iterate a few times on steps 4 and 5 until you find a good config for this model to run a bigger batch size. Or you may fail to find one if memory optimization does not apply well to the model.
+
+## Optimization Configuration
+
+The basic optimization unit is represented with a unique `cluster id`; for example, `BiasGelu+` is one `cluster id`.
+The `cluster id` is followed by the `optimization strategy`: 0 - none, 1 - recompute, 2 - recompute with compromised memory saving.
+The `optimization strategy` is followed by the `request count` for applying the given optimization; use `-1` to apply it to all occurrences. This gives users a bit more flexibility to avoid unnecessary memory savings. A few complete examples follow below.
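Putting the three fields together, typical entries look like this (the subgraph names come from the sample logs above and will differ per model):

```bash
# Format: <cluster id>:<optimization strategy>:<request count>
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:-1"             # recompute all BiasGelu+ occurrences
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:2"              # recompute only the first 2 occurrences
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:-1,Cast+:2:-1"  # combine plans with a comma
```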
+### Compromised Recompute
+
+If you check the above logs, there is a config `Cast+:2:-1`; the `2` indicates it's a recomputation that can save part of the stashed activation size, not all of it. Recomputing the subgraphs under it usually saves part of the activations (for example, half of them), not all. Follow the same way to enable it.
+
+## Dev Notes
+
+### Memory Optimization Debug Infos
+
+Use the following log level:
+> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
+
+Besides the logs shown at `LogLevel.INFO`, you can also see the different node patterns that can apply different optimization options.
+
+The way we get the table:
+- A specific node might have different optimization options; we [generate](../orttraining/orttraining/core/optimizer/memory_optimizer/common.h#L124C26-L124C26) a hash (called `Node Cluster ID`) for the node according to all available optimization options.
+- All nodes having the same `Node Cluster ID` are mapped into buckets; each bucket is displayed as one row.
-## Compromised Recompute
+```
+MemoryInsight Summary - User config: not provided
+===========================================================================================================================================
+|Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) |
+|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
+|6 |For each row options are mutually exclusive, only one of them can be enabled. |
+| | |
+| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ |
+| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 |
+| | Stashed Activations: |
+| | - ReuseFreq : Output 0(6), |
+| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(32)*(240))], byte/elem: 2, 100% saved |
+|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
+|5 |For each row options are mutually exclusive, only one of them can be enabled. |
+| | |
+| |>>Option 1 : Recompute subgraph FusedMatMul+ |
+| | Status : Disabled.
+
+## Dev Notes
+
+### Memory Optimization Debug Infos
+
+Use the following log level:
+> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
+
+Besides the logs shown at `LogLevel.INFO`, you will also see the different node patterns that different optimization options can apply to.
+
+How the table is generated:
+- For a specific node, there might be different optimization options; we [generate](../orttraining/orttraining/core/optimizer/memory_optimizer/common.h#L124C26-L124C26) a hash (called `Node Cluster ID`) for the node according to all available optimization options.
+- All nodes having the same `Node Cluster ID` are mapped into buckets, and each bucket is displayed as one row.
-## Compromised Recompute
+```
+MemoryInsight Summary - User config: not provided
+===========================================================================================================================================
+|Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) |
+|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
+|6 |For each row options are mutually exclusive, only one of them can be enabled. |
+| | |
+| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ |
+| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 |
+| | Stashed Activations: |
+| | - ReuseFreq : Output 0(6), |
+| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(32)*(240))], byte/elem: 2, 100% saved |
+|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
+|5 |For each row options are mutually exclusive, only one of them can be enabled. |
+| | |
+| |>>Option 1 : Recompute subgraph FusedMatMul+ |
+| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
+| | Stashed Activations: |
+| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(10240))], byte/elem: 2, 100% saved |
+|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
+|5 |For each row options are mutually exclusive, only one of them can be enabled. |
+| | |
+| |>>Option 1 : Recompute subgraph Cast+ |
+| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
+| | Stashed Activations: |
+| | - Output 0 : [((inputs_input_ids_dim0)*(32)*(inputs_input_ids_dim1)*(inputs_input_ids_dim1))], byte/elem: 2, 100% saved |
+|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
+|1 |For each row options are mutually exclusive, only one of them can be enabled. |
+| | |
+| |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ |
+| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 |
+| | Stashed Activations: |
+| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 100% saved |
+| | |
+| |>>Option 2 : RecomputeWithCompromise subgraph Cast+ |
+| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:2:-1 |
+| | Stashed Activations: |
+| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 50% saved |
+|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
-If you check the above logs, there is a separate section called "RecomputeWithCompromise". Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it.
+```

## Notes
-The feature is in experimental stage, we will tune and refine it according to real use cases.
+The feature is in the experimental stage; we will tune and refine it according to real use cases.
diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 12733c3551704..bede16204d420 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -146,7 +146,6 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
   export ORTMODULE_ONNX_OPSET_VERSION=14
   ```
-
 #### ORTMODULE_FALLBACK_POLICY
 - **Feature Area**: *ORTMODULE/FallbackToPytorch*
@@ -155,7 +154,6 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
   export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
   ```
-
 #### ORTMODULE_LOG_LEVEL
 - **Feature Area**: *ORTMODULE/DebugOptions*
@@ -182,7 +180,6 @@ The output directory of the onnx models by default is set to the current working
 > On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
 > Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it.
-
 #### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD
 - **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
@@ -199,8 +196,6 @@ The output directory of the onnx models by default is set to the current working
   enable_custom_autograd_support(False)
   ```
-
-
 #### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER
 - **Feature Area**: *ORTMODULE/Optimizations*
@@ -269,6 +264,35 @@ data sparsity based performance optimizations.
   unset ORTMODULE_CACHE_DIR # Disable
   ```
+#### ORTMODULE_USE_EFFICIENT_ATTENTION
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires PyTorch version 2.1.1 or above. There are some built-in patterns for attention fusion; if none of the patterns works for your model, you can add a custom one in your user script manually.
+
+  ```bash
+  export ORTMODULE_USE_EFFICIENT_ATTENTION=1
+  ```
+
+#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the module deep copy when preparing the output data that will be used by the ONNX export.
+A typical case for disabling the deep copy: if the deep copy before model export causes the memory peak, disable it and try again.
+
+  ```bash
+  export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable
+  export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
+  ```
+
+#### ORTMODULE_MEMORY_OPT_LEVEL
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, the level is 0. This env var can be used for enabling recomputation to reduce the memory peak requirement. Setting the level to 1 means all detected subgraphs within each transformer-based model layer generating stashed activations will be recomputed; this is conceptually equivalent to PyTorch's gradient checkpointing. When the level is 0, check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for how to enable recompute for selected subgraphs via `ORTMODULE_MEMORY_OPT_CONFIG`.
+
+  ```bash
+  export ORTMODULE_MEMORY_OPT_LEVEL=0
+  ```
+
 ### 2.2 Memory Optimization
 Q: *Want to run a bigger batch size?*
@@ -370,6 +394,30 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
   export ORTMODULE_USE_TRITON=1
   ```
+#### ORTMODULE_TRITON_CONFIG_FILE
+
+- **Feature Area**: *ORTMODULE/TritonOp*
+- **Description**: Triton codegen currently supports a subset of ops, such as some elementwise ops and some reduction ops. If Triton optimization is enabled, all these supported ops will be optimized by default when possible. Users can provide a customized JSON config file to control which ops to optimize and how to optimize them. Below is a sample config JSON. For each op, the opset version list is needed, plus the domain if it is not the default one. Currently the "conditions" field can be used to constrain the axis/axes attribute or input: specify the concrete value, or "single", meaning it contains only one dimension, or "constant", meaning it must be a constant tensor. Save the JSON to a file somewhere and assign its path to the env variable below to enable the customized config.
+
+  ```json
+  {
+    "ops": {
+      "Add": {"versions": [13, 14]},
+      "Sub": {"versions": [13, 14]},
+      "Identity": {"versions": [13], "is_no_op": true},
+      "ReduceSum": {"versions": [13], "conditions": {"axes": "[-1]"}},
+      "Softmax": {"versions": [13]},
+      "SoftmaxGrad_13": {"domain": "com.microsoft", "versions": [1]}
+    },
+    "initializer": "scalar",
+    "min_nodes": 2
+  }
+  ```
+
+  ```bash
+  export ORTMODULE_TRITON_CONFIG_FILE=triton_config.json
+  ```
+
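+  Because the file must be plain JSON (lowercase `true`/`false`, unlike Python literals), a quick validation before enabling it can save a debugging round trip. A minimal sketch, assuming the sample above is saved as `triton_config.json`:
+
+  ```python
+  import json
+
+  # Fails fast if the config is not valid JSON (e.g. Python-style True/False).
+  with open("triton_config.json") as f:
+      config = json.load(f)
+
+  # Every op entry is expected to carry an opset version list.
+  for op_type, spec in config["ops"].items():
+      assert "versions" in spec, f"{op_type} is missing its opset version list"
+  ```
+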
 #### ORTMODULE_ENABLE_TUNING
 - **Feature Area**: *ORTMODULE/TritonOp*
@@ -397,6 +445,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
   export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results
   ```
+#### ORTMODULE_USE_FLASH_ATTENTION
+
+- **Feature Area**: *ORTMODULE/TritonOp*
+- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. NOTE that it requires ORTMODULE_USE_TRITON to be enabled and a CUDA device with compute capability 8.0 or above. There are some built-in patterns for attention fusion; if none of the patterns works for your model, you can add a custom one in your user script manually.
+
+  ```bash
+  export ORTMODULE_USE_FLASH_ATTENTION=1
+  ```
+
 #### ORTMODULE_TRITON_DEBUG
 - **Feature Area**: *ORTMODULE/TritonOp*
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 26b5ebbdbec36..1ce9b3254d91f 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -80,7 +80,8 @@ Do not modify directly.*
|Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -373,7 +374,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)| +|SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)| |Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -840,7 +841,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/docs/python/api_summary.rst b/docs/python/api_summary.rst index cecd62aff15c4..092b42010a5c6 100644 --- a/docs/python/api_summary.rst +++ b/docs/python/api_summary.rst @@ -274,6 +274,77 @@ SessionOptions .. autoclass:: onnxruntime.SessionOptions :members: +.. autoclass:: onnxruntime.ExecutionMode + :members: + +.. autoclass:: onnxruntime.ExecutionOrder + :members: + +.. autoclass:: onnxruntime.GraphOptimizationLevel + :members: + +.. autoclass:: onnxruntime.OrtAllocatorType + :members: + +.. autoclass:: onnxruntime.OrtArenaCfg + :members: + +.. autoclass:: onnxruntime.OrtMemoryInfo + :members: + +.. autoclass:: onnxruntime.OrtMemType + :members: + +Functions +--------- + +Allocators +^^^^^^^^^^ + +.. autofunction:: onnxruntime.create_and_register_allocator + +.. autofunction:: onnxruntime.create_and_register_allocator_v2 + +Telemetry events +^^^^^^^^^^^^^^^^ + +.. autofunction:: onnxruntime.disable_telemetry_events + +.. autofunction:: onnxruntime.enable_telemetry_events + +Providers +^^^^^^^^^ + +.. autofunction:: onnxruntime.get_all_providers + +.. autofunction:: onnxruntime.get_available_providers + +Build, Version +^^^^^^^^^^^^^^ + +.. autofunction:: onnxruntime.get_build_info + +.. autofunction:: onnxruntime.get_version_string + +.. autofunction:: onnxruntime.has_collective_ops + +Device +^^^^^^ + +.. autofunction:: onnxruntime.get_device + +Logging +^^^^^^^ + +.. autofunction:: onnxruntime.set_default_logger_severity + +.. autofunction:: onnxruntime.set_default_logger_verbosity + +Random +^^^^^^ + +.. autofunction:: onnxruntime.set_seed + Data ---- @@ -298,6 +369,9 @@ IOBinding .. autoclass:: onnxruntime.IOBinding :members: +.. autoclass:: onnxruntime.SessionIOBinding + :members: + OrtDevice ^^^^^^^^^ diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index b3783696b8d78..82a1c1de83523 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -2,34 +2,17 @@ // Licensed under the MIT License. #pragma once -#include -#include + #include -#include #include -#include "core/common/gsl.h" -#include "onnxruntime_config.h" - -#ifndef DISABLE_ABSEIL -// Need to include abseil inlined_vector.h header directly here -// as hash tables cause CUDA 10.2 compilers to fail. inlined_vector.h is fine. -#ifdef _MSC_VER -#pragma warning(push) -// C4127: conditional expression is constant -#pragma warning(disable : 4127) -// C4324: structure was padded due to alignment specifier -// Usage of alignas causes some internal padding in places. -#pragma warning(disable : 4324) -#endif - -#include - -#ifdef _MSC_VER -#pragma warning(pop) -#endif -#endif // DISABLE_ABSEIL +#include +#include +#include +#include "core/common/gsl.h" +#include "core/common/inlined_containers_fwd.h" #include "core/common/span_utils.h" +#include "onnxruntime_config.h" namespace onnxruntime { #ifdef __GNUC__ @@ -41,18 +24,10 @@ namespace onnxruntime { constexpr size_t kTensorShapeSmallBufferElementsSize = 5; -#ifndef DISABLE_ABSEIL // Use this type to build a shape and then create TensorShape. -using TensorShapeVector = absl::InlinedVector; -#else -class TensorShapeVector : public std::vector { - using Base = std::vector; - - public: - using Base::Base; -}; - -#endif // DISABLE_ABSEIL +// We opt to re-use a common instantiation instead of a typedef with kTensorShapeSmallBufferElementsSize +// To reduce on binary size. 
+using TensorShapeVector = InlinedVector; inline TensorShapeVector ToShapeVector(const gsl::span& span) { TensorShapeVector out; @@ -194,9 +169,7 @@ class TensorShape { friend struct ProviderHostImpl; // So that the shared provider interface can access Allocate }; -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif + // operator<< to nicely output to a stream std::ostream& operator<<(std::ostream& out, const TensorShape& shape); diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 7e59aad80cc47..9b26ba914c7dd 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -55,4 +55,7 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; +// For Priority based graph topology sorting. +constexpr const char* kBackwardNodeAttributeName = "__backwardpass"; + } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index fe0734c51f807..22827d43b200f 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -668,7 +668,7 @@ class Node { The Graph representation containing the graph inputs and outputs, the Node instances, and the edges connecting the nodes. */ -class Graph { +class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve existing data member order for readability public: /** Gets the Graph name. */ const std::string& Name() const noexcept; diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index cf3ddc3f125f9..7d7f05193f486 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -37,9 +37,13 @@ enum OrtDmlPerformancePreference { }; enum OrtDmlDeviceFilter : uint32_t { +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION Any = 0xffffffff, Gpu = 1 << 0, Npu = 1 << 1, +#else + Gpu = 1 << 0, +#endif }; inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cddad732104ed..c41700453a73b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3598,6 +3598,7 @@ struct OrtApi { * "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided. * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off. * "rpc_control_latency": QNN RPC control latency. + * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". * "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model. 
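As a sketch of how these QNN provider options are passed in practice (shown here via the Python API; the model path and backend library name are placeholders):

```python
import onnxruntime as ort

# Placeholder model and backend path; the option keys follow the list above,
# including the newly documented "vtcm_mb".
qnn_options = {
    "backend_path": "QnnHtp.dll",
    "profiling_level": "basic",
    "vtcm_mb": "8",
    "htp_performance_mode": "burst",
}
session = ort.InferenceSession(
    "model.onnx",
    providers=[("QNNExecutionProvider", qnn_options)],
)
```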
diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 443710884743a..0c0af16d4e20c 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -399,6 +399,15 @@ struct TensorArray : public ArgBase { using Variadic = TensorArray; +/* +Note: +OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and the ort core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release it, since there is no virtual destructor in the hierarchy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in the form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a C struct bearing no v-table, so offspring structs are by design to have zero virtual functions to maintain cast safety. +*/ struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; using OptionalFloatTensor = std::optional>; @@ -774,10 +783,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtLiteCustomOp(const char* op_name, const char* execution_provider, - int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), - execution_provider_(execution_provider), - start_ver_(start_ver), - end_ver_(end_ver) { + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -858,8 +870,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + ShapeInferFn shape_infer_fn_ = {}; + int start_ver_ = 1; int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -891,9 +908,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFn compute_fn, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -905,7 +921,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_ = static_cast(this_)->compute_fn_; + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -931,9 +948,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFnReturnStatus compute_fn_return_status,
ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -945,7 +961,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_return_status_ = static_cast(this_)->compute_fn_return_status_; + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -965,10 +982,6 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { }; } } - - ComputeFn compute_fn_ = {}; - ComputeFnReturnStatus compute_fn_return_status_ = {}; - ShapeInferFn shape_infer_fn_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -1007,7 +1020,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtLiteCustomStruct(const char* op_name, const char* execution_provider, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 831def24e4f5e..a94973b2cc5d7 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -80,17 +80,17 @@ static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = #ifdef ENABLE_TRAINING // Specifies a list of op types for memory footprint reduction. // The value should be a ","-delimited list of pair of -// . +// . // For example, "Gelu+Cast+:1:0,Dropout+:1:1". // A valid "subgraph string" should be one subgraph representation output by ORT graph transformations. // "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute. // "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving" // the memory. -static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer"; +static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config"; -// Specifies the level for detecting subgraphs for memory footprint reduction. -// The value should be an integer. The default value is 0. 
-static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level"; +// Specifies the config for detecting subgraphs for memory footprint reduction. +// The value should be a string contains int separated using commas. The default value is "0:0". +static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config"; #endif // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". diff --git a/js/.eslintrc.js b/js/.eslintrc.js index fd30cb96a5bd0..0bf47c5264f61 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -5,10 +5,18 @@ module.exports = { root: true, - ignorePatterns: ['**/*.js', 'ort-schema/', 'common/test/type-tests/', 'test/data/', 'node_modules/', 'dist/'], + ignorePatterns: [ + '**/*.js', + 'node_modules/', + 'ort-schema/', + 'common/test/type-tests/', + 'web/types.d.ts', + 'test/data/', + 'dist/', + ], env: { 'es6': true }, parser: '@typescript-eslint/parser', - parserOptions: { 'project': 'tsconfig.json', 'sourceType': 'module' }, + parserOptions: { 'project': true, 'sourceType': 'module' }, plugins: ['@typescript-eslint', 'prefer-arrow', 'header', 'import', 'unicorn', 'jsdoc'], rules: { 'unicorn/filename-case': 'error', @@ -144,15 +152,56 @@ module.exports = { 'no-unused-expressions': 'off', } }, { - files: ['web/lib/**/*.ts'], - excludedFiles: 'web/lib/wasm/proxy-worker/**/*', - parserOptions: { 'project': 'web/tsconfig.json' }, - rules: { - 'no-underscore-dangle': 'off', + files: ['web/lib/**/*.ts'], rules: { + 'no-underscore-dangle': ['error', { + 'allow': [ + '_free', + '_malloc', + '_JsepGetNodeName', + '_JsepOutput', + '_OrtAddFreeDimensionOverride', + '_OrtAddRunConfigEntry', + '_OrtAddSessionConfigEntry', + '_OrtAppendExecutionProvider', + '_OrtBindInput', + '_OrtBindOutput', + '_OrtClearBoundOutputs', + '_OrtCreateBinding', + '_OrtCreateRunOptions', + '_OrtCreateSession', + '_OrtCreateSessionOptions', + '_OrtCreateTensor', + '_OrtEndProfiling', + '_OrtFree', + '_OrtGetInputName', + '_OrtGetInputOutputCount', + '_OrtGetLastError', + '_OrtGetOutputName', + '_OrtGetTensorData', + '_OrtInit', + '_OrtReleaseBinding', + '_OrtReleaseRunOptions', + '_OrtReleaseSession', + '_OrtReleaseSessionOptions', + '_OrtReleaseTensor', + '_OrtRun', + '_OrtRunWithBinding', + '_OrtTrainingCopyParametersFromBuffer', + '_OrtTrainingCopyParametersToBuffer', + '_OrtTrainingCreateSession', + '_OrtTrainingEvalStep', + '_OrtTrainingGetModelInputOutputCount', + '_OrtTrainingGetModelInputOutputName', + '_OrtTrainingGetParametersSize', + '_OrtTrainingLazyResetGrad', + '_OrtTrainingLoadCheckpoint', + '_OrtTrainingOptimizerStep', + '_OrtTrainingReleaseCheckpoint', + '_OrtTrainingReleaseSession', + '_OrtTrainingRunTrainStep' + ] + }] } - }, { - files: ['web/lib/wasm/proxy-worker/**/*.ts'], - parserOptions: { 'project': 'web/lib/wasm/proxy-worker/tsconfig.json' }, }, { files: ['web/lib/onnxjs/**/*.ts'], rules: { // TODO: those rules are useful. should turn on them in future (webgl refactor) @@ -164,6 +213,7 @@ module.exports = { 'import/no-internal-modules': 'off', 'prefer-arrow/prefer-arrow-functions': 'off', 'no-param-reassign': 'off', + 'no-underscore-dangle': 'off', 'guard-for-in': 'off' } }, { diff --git a/js/README.md b/js/README.md index 7e6681e6bd897..1662de6d4ac78 100644 --- a/js/README.md +++ b/js/README.md @@ -344,13 +344,13 @@ From ORT v1.13 onwards the 'full' ONNX Runtime package is used. 
It supports both Full build: ```sh - python tools/ci_build/github/apple/build_ios_framework.py tools/ci_build/github/apple/default_full_ios_framework_build_settings.json --config Release + python tools/ci_build/github/apple/build_apple_framework.py tools/ci_build/github/apple/default_full_apple_framework_build_settings.json --config Release ``` Reduced size build: ```sh - python tools/ci_build/github/apple/build_ios_framework.py tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json --config MinSizeRel --include_ops_by_config --enable_reduced_operator_type_support + python tools/ci_build/github/apple/build_apple_framework.py tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json --config MinSizeRel --include_ops_by_config --enable_reduced_operator_type_support ``` The build creates `Headers`, `LICENSE`, and `onnxruntime.xcframework` in `build/iOS_framework/framework_out` directory. From `framework_out` directory, create an archive file named `onnxruntime-c.zip` for a full build or `onnxruntime-mobile-c.zip` for a reduced size build and copy to `/js/react_native/local_pods` directory. diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index dd04ef3f15997..5460ae086fc2f 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -45,12 +45,21 @@ export interface InferenceSessionHandler extends SessionHandler { * @ignore */ export interface TrainingSessionHandler extends SessionHandler { + readonly evalInputNames: readonly string[]; + readonly evalOutputNames: readonly string[]; + + lazyResetGrad(): Promise; runTrainStep( feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise; + runOptimizerStep(options: InferenceSession.RunOptions): Promise; + runEvalStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise; + getParametersSize(trainableOnly: boolean): Promise; loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; } /** diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 76575ef7b9368..0cded7e5edbcb 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -92,11 +92,48 @@ export declare namespace Env { async?: boolean; } + export interface WebGpuProfilingDataV1TensorMetadata { + dims: readonly number[]; + dataType: string; + } + export interface WebGpuProfilingDataV1 { + version: 1; + inputsMetadata: readonly WebGpuProfilingDataV1TensorMetadata[]; + outputsMetadata: readonly WebGpuProfilingDataV1TensorMetadata[]; + kernelId: number; + kernelType: string; + kernelName: string; + startTime: number; + endTime: number; + } + + export type WebGpuProfilingData = WebGpuProfilingDataV1; + export interface WebGpuFlags { /** * Set or get the profiling mode. + * + * @deprecated Use `env.webgpu.profiling.mode` instead. If `env.webgpu.profiling.mode` is set, this property will be + * ignored. */ profilingMode?: 'off'|'default'; + /** + * Set or get the profiling configuration. + */ + profiling?: { + /** + * Set or get the profiling mode. + * + * @defaultValue `'off'` + */ + mode?: 'off'|'default'; + + /** + * Set or get a callback function when a profiling data is received. If not set, the profiling data will be + * printed to console. + */ + ondata?: (data: WebGpuProfilingData) => void; + }; /** * Get the device for WebGPU. 
* diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index ee6d26b22b1f6..23bd4421ae672 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { - private constructor(handler: TrainingSessionHandler) { + private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) { this.handler = handler; + this.hasOptimizerModel = hasOptimizerModel; + this.hasEvalModel = hasEvalModel; } private handler: TrainingSessionHandler; + private hasOptimizerModel: boolean; + private hasEvalModel: boolean; - get inputNames(): readonly string[] { + get trainingInputNames(): readonly string[] { return this.handler.inputNames; } - get outputNames(): readonly string[] { + get trainingOutputNames(): readonly string[] { return this.handler.outputNames; } + get evalInputNames(): readonly string[] { + if (this.hasEvalModel) { + return this.handler.evalInputNames; + } else { + throw new Error('This training session has no evalModel loaded.'); + } + } + get evalOutputNames(): readonly string[] { + if (this.hasEvalModel) { + return this.handler.evalOutputNames; + } else { + throw new Error('This training session has no evalModel loaded.'); + } + } + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; @@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface { if (backend.createTrainingSessionHandler) { const handler = await backend.createTrainingSessionHandler( trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); - return new TrainingSession(handler); + return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); } else { throw new Error(noBackendErrMsg); } @@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface { * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from * the given parameters to SessionHandler.FetchesType and RunOptions. * + * @param inputNames the feeds object is checked that they contain all input names in the provided list of input + * names. + * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output + * names. * @param feeds the required input * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object * @param arg2 optional RunOptions object. 
* @returns */ - typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): - [SessionHandler.FetchesType, RunOptions] { + typeNarrowingForRunStep( + inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions, + arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] { const fetches: {[name: string]: OnnxValue|null} = {}; let options: RunOptions = {}; // check inputs @@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface { if (typeof name !== 'string') { throw new TypeError('\'fetches\' must be a string array or an object.'); } - if (this.outputNames.indexOf(name) === -1) { + if (outputNames.indexOf(name) === -1) { throw new RangeError(`'fetches' contains invalid output name: ${name}.`); } fetches[name] = null; @@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface { // if any output name is present and its value is valid OnnxValue, we consider it fetches let isFetches = false; const arg1Keys = Object.getOwnPropertyNames(arg1); - for (const name of this.outputNames) { + for (const name of outputNames) { if (arg1Keys.indexOf(name) !== -1) { const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; if (v === null || v instanceof Tensor) { @@ -130,7 +154,7 @@ export class TrainingSession implements TrainingSessionInterface { } // check if all inputs are in feed - for (const name of this.inputNames) { + for (const name of inputNames) { if (typeof feeds[name] === 'undefined') { throw new Error(`input '${name}' is missing in 'feeds'.`); } @@ -138,7 +162,7 @@ export class TrainingSession implements TrainingSessionInterface { // if no fetches is specified, we use the full output names list if (isFetchesEmpty) { - for (const name of this.outputNames) { + for (const name of outputNames) { fetches[name] = null; } } @@ -168,20 +192,58 @@ export class TrainingSession implements TrainingSessionInterface { return returnValue; } + async lazyResetGrad(): Promise { + await this.handler.lazyResetGrad(); + } + runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { - const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2); + const [fetches, options] = + this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2); const results = await this.handler.runTrainStep(feeds, fetches, options); return this.convertHandlerReturnTypeToMapOfTensors(results); } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise { + if (this.hasOptimizerModel) { + await this.handler.runOptimizerStep(options || {}); + } else { + throw new Error('This TrainingSession has no OptimizerModel loaded.'); + } + } + + runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise; + runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise; + async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + if (this.hasEvalModel) { + const [fetches, options] = + this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2); + const results = await this.handler.runEvalStep(feeds, 
fetches, options); + return this.convertHandlerReturnTypeToMapOfTensors(results); + } else { + throw new Error('This TrainingSession has no EvalModel loaded.'); + } + } + + async getParametersSize(trainableOnly = true): Promise { + return this.handler.getParametersSize(trainableOnly); + } + + async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise { + const paramsSize = await this.getParametersSize(trainableOnly); + // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number + // of parameters + if (array.length !== 4 * paramsSize) { + throw new Error( + 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + + 'the model. Please use getParametersSize method to check.'); + } + return this.handler.loadParametersBuffer(array, trainableOnly); } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async getContiguousParameters(trainableOnly = true): Promise { + return this.handler.getContiguousParameters(trainableOnly); } async release(): Promise { diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index 0967d79b33434..e54aed90e702c 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -21,6 +22,12 @@ export declare namespace TrainingSession { export interface TrainingSession { // #region run() + /** + * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of + * runOptimizerStep. + */ + lazyResetGrad(): Promise; + /** * Run TrainStep asynchronously with the given feeds and options. * @@ -38,7 +45,7 @@ export interface TrainingSession { * @param feeds - Representation of the model input. * @param fetches - Representation of the model output. * detail. - * @param options - Optional. A set of options that controls the behavior of model inference. + * @param options - Optional. A set of options that controls the behavior of model training. * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. */ @@ -46,24 +53,68 @@ export interface TrainingSession { feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, options?: InferenceSession.RunOptions): Promise; + /** + * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model. + * + * @param options - Optional. A set of options that controls the behavior of model optimizing. + */ + runOptimizerStep(options?: InferenceSession.RunOptions): Promise; + + /** + * Run a single eval step with the given inputs and options using the eval model. + * + * @param feeds - Representation of the model input. + * @param options - Optional. A set of options that controls the behavior of model eval step. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): + Promise; + + /** + * Run a single eval step with the given inputs and options using the eval model. + * + * @param feeds - Representation of the model input. 
+ * @param fetches - Representation of the model output. + * detail. + * @param options - Optional. A set of options that controls the behavior of model eval step. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runEvalStep( + feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions): Promise; + // #endregion // #region copy parameters + /** - * Copies from a buffer containing parameters to the TrainingSession parameters. + * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of + * the parameters) elements of all the parameters in the training state. * - * @param buffer - buffer containing parameters - * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. + * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true. + */ + getParametersSize(trainableOnly: boolean): Promise; + + /** + * Copies parameter values from the given array to the training state. Currently, only supporting models with + * parameters of type Float32. + * + * @param buffer - Float32 buffer containing parameters converted to a Uint8Array. + * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. */ loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; /** - * Copies from the TrainingSession parameters to a buffer. + * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. + * Currently, only supporting models with parameters of type Float32. * - * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. - * @returns A promise that resolves to a buffer of the requested parameters. + * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters + * for which requires_grad is set to true. Default value is true. + * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters. */ - getContiguousParameters(trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; // #endregion // #region release() @@ -77,14 +128,25 @@ export interface TrainingSession { // #region metadata /** - * Get input names of the loaded model. + * Get input names of the loaded training model. + */ + readonly trainingInputNames: readonly string[]; + + /** + * Get output names of the loaded training model. */ - readonly inputNames: readonly string[]; + readonly trainingOutputNames: readonly string[]; /** - * Get output names of the loaded model. + * Get input names of the loaded eval model. Is an empty array if no eval model is loaded. */ - readonly outputNames: readonly string[]; + readonly evalInputNames: readonly string[]; + + /** + * Get output names of the loaded eval model. Is an empty array if no eval model is loaded. 
+ */ + readonly evalOutputNames: readonly string[]; + // #endregion } diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index 5f5ad49a2dea8..e8eb0e9babf5a 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -20,7 +20,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { } async dispose(): Promise { - return Promise.resolve(); + this.#inferenceSession.dispose(); } readonly inputNames: string[]; diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts index 8a0ce89abfa64..54b5767139904 100644 --- a/js/node/lib/binding.ts +++ b/js/node/lib/binding.ts @@ -28,6 +28,8 @@ export declare namespace Binding { readonly outputNames: string[]; run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType; + + dispose(): void; } export interface InferenceSessionConstructor { diff --git a/js/node/package-lock.json b/js/node/package-lock.json index e8968bafc4a9f..c1cf8af4bb80e 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -22,7 +22,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" } }, "../common": { @@ -97,12 +97,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "node_modules/@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "node_modules/@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -528,9 +522,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "node_modules/lru-cache": { @@ -663,15 +657,6 @@ "node": "^12.13.0 || ^14.15.0 || >=16.0.0" } }, - "node_modules/onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "dependencies": { - "protobufjs": "^6.11.2" - } - }, "node_modules/onnxruntime-common": { "resolved": "../common", "link": true @@ -690,9 +675,9 @@ } }, "node_modules/protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "hasInstallScript": true, "dependencies": { @@ -706,13 +691,11 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" }, - "bin": { - "pbjs": "bin/pbjs", - "pbts": "bin/pbts" + "engines": { + "node": ">=12.0.0" } }, "node_modules/proxy-from-env": { @@ 
-789,9 +772,9 @@ ] }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -1070,12 +1053,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -1413,9 +1390,9 @@ "dev": true }, "long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "lru-cache": { @@ -1523,15 +1500,6 @@ "set-blocking": "^2.0.0" } }, - "onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "requires": { - "protobufjs": "^6.11.2" - } - }, "onnxruntime-common": { "version": "file:../common", "requires": { @@ -1549,9 +1517,9 @@ } }, "protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "requires": { "@protobufjs/aspromise": "^1.1.2", @@ -1564,9 +1532,8 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" } }, "proxy-from-env": { @@ -1619,9 +1586,9 @@ "dev": true }, "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "requires": { "lru-cache": "^6.0.0" diff --git a/js/node/package.json b/js/node/package.json index 0f8f0e9d2260c..8e591d8f46b9d 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -19,6 +19,7 @@ }, "scripts": { "buildr": "tsc && node ./script/build --config=RelWithDebInfo", + "preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', 
'./node_modules/long/umd/index.d.ts')\"", "prepare": "tsc --build script test .", "rebuild": "tsc && node ./script/build --rebuild", "rebuildd": "tsc && node ./script/build --rebuild --config=Debug", @@ -39,7 +40,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" }, "main": "dist/index.js", "os": [ diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index c409fdc8895f7..1bbb6df1ce1c8 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -31,6 +31,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { Napi::Function func = DefineClass( env, "InferenceSession", {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run), + InstanceMethod("dispose", &InferenceSessionWrap::Dispose), InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr), InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)}); @@ -45,7 +46,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { } InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo &info) - : Napi::ObjectWrap(info), initialized_(false), session_(nullptr), + : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {} Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { @@ -53,6 +54,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { Napi::HandleScope scope(env); ORT_NAPI_THROW_ERROR_IF(this->initialized_, env, "Model already loaded. Cannot load model multiple times."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); size_t argsLength = info.Length(); ORT_NAPI_THROW_TYPEERROR_IF(argsLength == 0, env, "Expect argument: model file path or buffer."); @@ -129,6 +131,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); Napi::EscapableHandleScope scope(env); return scope.Escape(CreateNapiArrayFrom(env, inputNames_)); @@ -137,6 +140,7 @@ Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); Napi::EscapableHandleScope scope(env); return scope.Escape(CreateNapiArrayFrom(env, outputNames_)); @@ -145,6 +149,7 @@ Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); ORT_NAPI_THROW_TYPEERROR_IF(info.Length() < 2, env, "Expect argument: inputs(feed) and outputs(fetch)."); ORT_NAPI_THROW_TYPEERROR_IF(!info[0].IsObject() || !info[1].IsObject(), env, "Expect inputs(feed) and outputs(fetch) to be 
objects."); @@ -209,6 +214,18 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { } } +Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo &info) { + Napi::Env env = info.Env(); + ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + + this->defaultRunOptions_.reset(nullptr); + this->session_.reset(nullptr); + + this->disposed_ = true; + return env.Undefined(); +} + Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); Napi::EscapableHandleScope scope(env); diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index 9eee45b72dcb1..1e789c4814cd6 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -55,6 +55,14 @@ class InferenceSessionWrap : public Napi::ObjectWrap { */ Napi::Value Run(const Napi::CallbackInfo &info); + /** + * [sync] dispose the session. + * @param nothing + * @returns nothing + * @throw nothing + */ + Napi::Value Dispose(const Napi::CallbackInfo &info); + // private members // persistent constructor @@ -62,6 +70,7 @@ class InferenceSessionWrap : public Napi::ObjectWrap { // session objects bool initialized_; + bool disposed_; std::unique_ptr session_; std::unique_ptr defaultRunOptions_; diff --git a/js/node/test/ort-schema/protobuf/.gitignore b/js/node/test/ort-schema/protobuf/.gitignore new file mode 100644 index 0000000000000..092bb6c1c9fb4 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/.gitignore @@ -0,0 +1,2 @@ +!onnx.js +!onnx.d.ts diff --git a/js/node/test/ort-schema/protobuf/README.md b/js/node/test/ort-schema/protobuf/README.md new file mode 100644 index 0000000000000..f5f52c602f1ad --- /dev/null +++ b/js/node/test/ort-schema/protobuf/README.md @@ -0,0 +1,21 @@ +# ONNX protobuf + +This directory contains generated protobuf definition for onnx: + +- onnx.js +- onnx.d.ts + +These files are generated from [a fork of onnx-proto](https://github.com/fs-eire/onnx-proto/tree/update-v9). + +The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the version contains 2 bugs: + +- type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. +- in the generated typescript declaration file 'onnx.d.ts', the following line: + ```ts + import Long = require("long"); + ``` + need to be replaced to fix type import error: + ```ts + import Long from "long"; + ``` + this replacement is done and code format is also applied to file 'onnx.d.ts'. diff --git a/js/node/test/ort-schema/protobuf/onnx.d.ts b/js/node/test/ort-schema/protobuf/onnx.d.ts new file mode 100644 index 0000000000000..c60264dca2a8d --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.d.ts @@ -0,0 +1,2627 @@ +import Long from 'long'; +import * as $protobuf from 'protobufjs'; + +/** Namespace onnx. */ +export namespace onnx { + + /** Version enum. */ + enum Version { + _START_VERSION = 0, + IR_VERSION_2017_10_10 = 1, + IR_VERSION_2017_10_30 = 2, + IR_VERSION_2017_11_3 = 3, + IR_VERSION_2019_1_22 = 4, + IR_VERSION_2019_3_18 = 5, + IR_VERSION_2019_9_19 = 6, + IR_VERSION_2020_5_8 = 7, + IR_VERSION_2021_7_30 = 8, + IR_VERSION = 9 + } + + /** Properties of an AttributeProto. 
*/ + interface IAttributeProto { + /** AttributeProto name */ + name?: (string|null); + + /** AttributeProto refAttrName */ + refAttrName?: (string|null); + + /** AttributeProto docString */ + docString?: (string|null); + + /** AttributeProto type */ + type?: (onnx.AttributeProto.AttributeType|null); + + /** AttributeProto f */ + f?: (number|null); + + /** AttributeProto i */ + i?: (number|Long|null); + + /** AttributeProto s */ + s?: (Uint8Array|null); + + /** AttributeProto t */ + t?: (onnx.ITensorProto|null); + + /** AttributeProto g */ + g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor */ + sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp */ + tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats */ + floats?: (number[]|null); + + /** AttributeProto ints */ + ints?: ((number | Long)[]|null); + + /** AttributeProto strings */ + strings?: (Uint8Array[]|null); + + /** AttributeProto tensors */ + tensors?: (onnx.ITensorProto[]|null); + + /** AttributeProto graphs */ + graphs?: (onnx.IGraphProto[]|null); + + /** AttributeProto sparseTensors */ + sparseTensors?: (onnx.ISparseTensorProto[]|null); + + /** AttributeProto typeProtos */ + typeProtos?: (onnx.ITypeProto[]|null); + } + + /** Represents an AttributeProto. */ + class AttributeProto implements IAttributeProto { + /** + * Constructs a new AttributeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IAttributeProto); + + /** AttributeProto name. */ + public name: string; + + /** AttributeProto refAttrName. */ + public refAttrName: string; + + /** AttributeProto docString. */ + public docString: string; + + /** AttributeProto type. */ + public type: onnx.AttributeProto.AttributeType; + + /** AttributeProto f. */ + public f: number; + + /** AttributeProto i. */ + public i: (number|Long); + + /** AttributeProto s. */ + public s: Uint8Array; + + /** AttributeProto t. */ + public t?: (onnx.ITensorProto|null); + + /** AttributeProto g. */ + public g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor. */ + public sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp. */ + public tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats. */ + public floats: number[]; + + /** AttributeProto ints. */ + public ints: (number|Long)[]; + + /** AttributeProto strings. */ + public strings: Uint8Array[]; + + /** AttributeProto tensors. */ + public tensors: onnx.ITensorProto[]; + + /** AttributeProto graphs. */ + public graphs: onnx.IGraphProto[]; + + /** AttributeProto sparseTensors. */ + public sparseTensors: onnx.ISparseTensorProto[]; + + /** AttributeProto typeProtos. */ + public typeProtos: onnx.ITypeProto[]; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns AttributeProto instance + */ + public static create(properties?: onnx.IAttributeProto): onnx.AttributeProto; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} + * messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link + * onnx.AttributeProto.verify|verify} messages. 
+ * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.AttributeProto; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.AttributeProto; + + /** + * Verifies an AttributeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns AttributeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.AttributeProto; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @param message AttributeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.AttributeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this AttributeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for AttributeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace AttributeProto { + + /** AttributeType enum. */ + enum AttributeType { + UNDEFINED = 0, + FLOAT = 1, + INT = 2, + STRING = 3, + TENSOR = 4, + GRAPH = 5, + SPARSE_TENSOR = 11, + TYPE_PROTO = 13, + FLOATS = 6, + INTS = 7, + STRINGS = 8, + TENSORS = 9, + GRAPHS = 10, + SPARSE_TENSORS = 12, + TYPE_PROTOS = 14 + } + } + + /** Properties of a ValueInfoProto. */ + interface IValueInfoProto { + /** ValueInfoProto name */ + name?: (string|null); + + /** ValueInfoProto type */ + type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString */ + docString?: (string|null); + } + + /** Represents a ValueInfoProto. */ + class ValueInfoProto implements IValueInfoProto { + /** + * Constructs a new ValueInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IValueInfoProto); + + /** ValueInfoProto name. */ + public name: string; + + /** ValueInfoProto type. */ + public type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString. */ + public docString: string; + + /** + * Creates a new ValueInfoProto instance using the specified properties. 
+ * @param [properties] Properties to set + * @returns ValueInfoProto instance + */ + public static create(properties?: onnx.IValueInfoProto): onnx.ValueInfoProto; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} + * messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link + * onnx.ValueInfoProto.verify|verify} messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ValueInfoProto; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ValueInfoProto; + + /** + * Verifies a ValueInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ValueInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ValueInfoProto; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @param message ValueInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ValueInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ValueInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ValueInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a NodeProto. */ + interface INodeProto { + /** NodeProto input */ + input?: (string[]|null); + + /** NodeProto output */ + output?: (string[]|null); + + /** NodeProto name */ + name?: (string|null); + + /** NodeProto opType */ + opType?: (string|null); + + /** NodeProto domain */ + domain?: (string|null); + + /** NodeProto attribute */ + attribute?: (onnx.IAttributeProto[]|null); + + /** NodeProto docString */ + docString?: (string|null); + } + + /** Represents a NodeProto. 
*/ + class NodeProto implements INodeProto { + /** + * Constructs a new NodeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.INodeProto); + + /** NodeProto input. */ + public input: string[]; + + /** NodeProto output. */ + public output: string[]; + + /** NodeProto name. */ + public name: string; + + /** NodeProto opType. */ + public opType: string; + + /** NodeProto domain. */ + public domain: string; + + /** NodeProto attribute. */ + public attribute: onnx.IAttributeProto[]; + + /** NodeProto docString. */ + public docString: string; + + /** + * Creates a new NodeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns NodeProto instance + */ + public static create(properties?: onnx.INodeProto): onnx.NodeProto; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link + * onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.NodeProto; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.NodeProto; + + /** + * Verifies a NodeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns NodeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.NodeProto; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @param message NodeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.NodeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this NodeProto to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for NodeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TrainingInfoProto. */ + interface ITrainingInfoProto { + /** TrainingInfoProto initialization */ + initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm */ + algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding */ + initializationBinding?: (onnx.IStringStringEntryProto[]|null); + + /** TrainingInfoProto updateBinding */ + updateBinding?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TrainingInfoProto. */ + class TrainingInfoProto implements ITrainingInfoProto { + /** + * Constructs a new TrainingInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITrainingInfoProto); + + /** TrainingInfoProto initialization. */ + public initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm. */ + public algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding. */ + public initializationBinding: onnx.IStringStringEntryProto[]; + + /** TrainingInfoProto updateBinding. */ + public updateBinding: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TrainingInfoProto instance + */ + public static create(properties?: onnx.ITrainingInfoProto): onnx.TrainingInfoProto; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} + * messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link + * onnx.TrainingInfoProto.verify|verify} messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TrainingInfoProto; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TrainingInfoProto; + + /** + * Verifies a TrainingInfoProto message. 
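`NodeProto` (declared above) exposes `opType`, `input`, and `output` as plain properties, so simple graph analyses reduce to array walks. A short sketch, with a helper name of my own choosing, that tallies operator usage:

```ts
import {onnx} from './onnx';  // the generated module from this directory

// Count how many times each operator type appears in a node list.
function countOpTypes(nodes: onnx.INodeProto[]): Map<string, number> {
  const counts = new Map<string, number>();
  for (const node of nodes) {
    const op = node.opType ?? '<unknown>';
    counts.set(op, (counts.get(op) ?? 0) + 1);
  }
  return counts;
}
```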
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TrainingInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TrainingInfoProto; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @param message TrainingInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TrainingInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TrainingInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TrainingInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a ModelProto. */ + interface IModelProto { + /** ModelProto irVersion */ + irVersion?: (number|Long|null); + + /** ModelProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** ModelProto producerName */ + producerName?: (string|null); + + /** ModelProto producerVersion */ + producerVersion?: (string|null); + + /** ModelProto domain */ + domain?: (string|null); + + /** ModelProto modelVersion */ + modelVersion?: (number|Long|null); + + /** ModelProto docString */ + docString?: (string|null); + + /** ModelProto graph */ + graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps */ + metadataProps?: (onnx.IStringStringEntryProto[]|null); + + /** ModelProto trainingInfo */ + trainingInfo?: (onnx.ITrainingInfoProto[]|null); + + /** ModelProto functions */ + functions?: (onnx.IFunctionProto[]|null); + } + + /** Represents a ModelProto. */ + class ModelProto implements IModelProto { + /** + * Constructs a new ModelProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IModelProto); + + /** ModelProto irVersion. */ + public irVersion: (number|Long); + + /** ModelProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** ModelProto producerName. */ + public producerName: string; + + /** ModelProto producerVersion. */ + public producerVersion: string; + + /** ModelProto domain. */ + public domain: string; + + /** ModelProto modelVersion. */ + public modelVersion: (number|Long); + + /** ModelProto docString. */ + public docString: string; + + /** ModelProto graph. */ + public graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps. */ + public metadataProps: onnx.IStringStringEntryProto[]; + + /** ModelProto trainingInfo. */ + public trainingInfo: onnx.ITrainingInfoProto[]; + + /** ModelProto functions. */ + public functions: onnx.IFunctionProto[]; + + /** + * Creates a new ModelProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ModelProto instance + */ + public static create(properties?: onnx.IModelProto): onnx.ModelProto; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. 
+ * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link + * onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ModelProto; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ModelProto; + + /** + * Verifies a ModelProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ModelProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ModelProto; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @param message ModelProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ModelProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ModelProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ModelProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a StringStringEntryProto. */ + interface IStringStringEntryProto { + /** StringStringEntryProto key */ + key?: (string|null); + + /** StringStringEntryProto value */ + value?: (string|null); + } + + /** Represents a StringStringEntryProto. */ + class StringStringEntryProto implements IStringStringEntryProto { + /** + * Constructs a new StringStringEntryProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IStringStringEntryProto); + + /** StringStringEntryProto key. */ + public key: string; + + /** StringStringEntryProto value. */ + public value: string; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. 
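`ModelProto.decode` takes a `Uint8Array` (or a protobufjs `Reader`) directly, so loading a test model is a one-liner. A sketch, assuming the generated module sits next to the caller and a local `model.onnx` exists:

```ts
import {readFileSync} from 'node:fs';
import {onnx} from './onnx';  // the generated onnx.js / onnx.d.ts pair

// Node's Buffer is a Uint8Array subclass, so it can be passed to decode() as-is.
const model = onnx.ModelProto.decode(readFileSync('./model.onnx'));
console.log(`ir_version=${model.irVersion} producer=${model.producerName}`);
```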
+ * @param [properties] Properties to set + * @returns StringStringEntryProto instance + */ + public static create(properties?: onnx.IStringStringEntryProto): onnx.StringStringEntryProto; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.StringStringEntryProto; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.StringStringEntryProto; + + /** + * Verifies a StringStringEntryProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns StringStringEntryProto + */ + public static fromObject(object: {[k: string]: any}): onnx.StringStringEntryProto; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @param message StringStringEntryProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.StringStringEntryProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this StringStringEntryProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for StringStringEntryProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorAnnotation. */ + interface ITensorAnnotation { + /** TensorAnnotation tensorName */ + tensorName?: (string|null); + + /** TensorAnnotation quantParameterTensorNames */ + quantParameterTensorNames?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TensorAnnotation. 
*/ + class TensorAnnotation implements ITensorAnnotation { + /** + * Constructs a new TensorAnnotation. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorAnnotation); + + /** TensorAnnotation tensorName. */ + public tensorName: string; + + /** TensorAnnotation quantParameterTensorNames. */ + public quantParameterTensorNames: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorAnnotation instance + */ + public static create(properties?: onnx.ITensorAnnotation): onnx.TensorAnnotation; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} + * messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link + * onnx.TensorAnnotation.verify|verify} messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorAnnotation; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorAnnotation; + + /** + * Verifies a TensorAnnotation message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorAnnotation + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorAnnotation; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @param message TensorAnnotation + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorAnnotation, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorAnnotation to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorAnnotation + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a GraphProto. */ + interface IGraphProto { + /** GraphProto node */ + node?: (onnx.INodeProto[]|null); + + /** GraphProto name */ + name?: (string|null); + + /** GraphProto initializer */ + initializer?: (onnx.ITensorProto[]|null); + + /** GraphProto sparseInitializer */ + sparseInitializer?: (onnx.ISparseTensorProto[]|null); + + /** GraphProto docString */ + docString?: (string|null); + + /** GraphProto input */ + input?: (onnx.IValueInfoProto[]|null); + + /** GraphProto output */ + output?: (onnx.IValueInfoProto[]|null); + + /** GraphProto valueInfo */ + valueInfo?: (onnx.IValueInfoProto[]|null); + + /** GraphProto quantizationAnnotation */ + quantizationAnnotation?: (onnx.ITensorAnnotation[]|null); + } + + /** Represents a GraphProto. */ + class GraphProto implements IGraphProto { + /** + * Constructs a new GraphProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IGraphProto); + + /** GraphProto node. */ + public node: onnx.INodeProto[]; + + /** GraphProto name. */ + public name: string; + + /** GraphProto initializer. */ + public initializer: onnx.ITensorProto[]; + + /** GraphProto sparseInitializer. */ + public sparseInitializer: onnx.ISparseTensorProto[]; + + /** GraphProto docString. */ + public docString: string; + + /** GraphProto input. */ + public input: onnx.IValueInfoProto[]; + + /** GraphProto output. */ + public output: onnx.IValueInfoProto[]; + + /** GraphProto valueInfo. */ + public valueInfo: onnx.IValueInfoProto[]; + + /** GraphProto quantizationAnnotation. */ + public quantizationAnnotation: onnx.ITensorAnnotation[]; + + /** + * Creates a new GraphProto instance using the specified properties. + * @param [properties] Properties to set + * @returns GraphProto instance + */ + public static create(properties?: onnx.IGraphProto): onnx.GraphProto; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link + * onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.GraphProto; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.GraphProto; + + /** + * Verifies a GraphProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns GraphProto + */ + public static fromObject(object: {[k: string]: any}): onnx.GraphProto; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @param message GraphProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.GraphProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this GraphProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for GraphProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorProto. */ + interface ITensorProto { + /** TensorProto dims */ + dims?: ((number | Long)[]|null); + + /** TensorProto dataType */ + dataType?: (number|null); + + /** TensorProto segment */ + segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData */ + floatData?: (number[]|null); + + /** TensorProto int32Data */ + int32Data?: (number[]|null); + + /** TensorProto stringData */ + stringData?: (Uint8Array[]|null); + + /** TensorProto int64Data */ + int64Data?: ((number | Long)[]|null); + + /** TensorProto name */ + name?: (string|null); + + /** TensorProto docString */ + docString?: (string|null); + + /** TensorProto rawData */ + rawData?: (Uint8Array|null); + + /** TensorProto externalData */ + externalData?: (onnx.IStringStringEntryProto[]|null); + + /** TensorProto dataLocation */ + dataLocation?: (onnx.TensorProto.DataLocation|null); + + /** TensorProto doubleData */ + doubleData?: (number[]|null); + + /** TensorProto uint64Data */ + uint64Data?: ((number | Long)[]|null); + } + + /** Represents a TensorProto. */ + class TensorProto implements ITensorProto { + /** + * Constructs a new TensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorProto); + + /** TensorProto dims. */ + public dims: (number|Long)[]; + + /** TensorProto dataType. */ + public dataType: number; + + /** TensorProto segment. */ + public segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData. */ + public floatData: number[]; + + /** TensorProto int32Data. */ + public int32Data: number[]; + + /** TensorProto stringData. */ + public stringData: Uint8Array[]; + + /** TensorProto int64Data. */ + public int64Data: (number|Long)[]; + + /** TensorProto name. */ + public name: string; + + /** TensorProto docString. */ + public docString: string; + + /** TensorProto rawData. */ + public rawData: Uint8Array; + + /** TensorProto externalData. */ + public externalData: onnx.IStringStringEntryProto[]; + + /** TensorProto dataLocation. 
*/ + public dataLocation: onnx.TensorProto.DataLocation; + + /** TensorProto doubleData. */ + public doubleData: number[]; + + /** TensorProto uint64Data. */ + public uint64Data: (number|Long)[]; + + /** + * Creates a new TensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorProto instance + */ + public static create(properties?: onnx.ITensorProto): onnx.TensorProto; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link + * onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto; + + /** + * Verifies a TensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @param message TensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorProto { + + /** DataType enum. 
*/ + enum DataType { + UNDEFINED = 0, + FLOAT = 1, + UINT8 = 2, + INT8 = 3, + UINT16 = 4, + INT16 = 5, + INT32 = 6, + INT64 = 7, + STRING = 8, + BOOL = 9, + FLOAT16 = 10, + DOUBLE = 11, + UINT32 = 12, + UINT64 = 13, + COMPLEX64 = 14, + COMPLEX128 = 15, + BFLOAT16 = 16, + FLOAT8E4M3FN = 17, + FLOAT8E4M3FNUZ = 18, + FLOAT8E5M2 = 19, + FLOAT8E5M2FNUZ = 20 + } + + /** Properties of a Segment. */ + interface ISegment { + /** Segment begin */ + begin?: (number|Long|null); + + /** Segment end */ + end?: (number|Long|null); + } + + /** Represents a Segment. */ + class Segment implements ISegment { + /** + * Constructs a new Segment. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorProto.ISegment); + + /** Segment begin. */ + public begin: (number|Long); + + /** Segment end. */ + public end: (number|Long); + + /** + * Creates a new Segment instance using the specified properties. + * @param [properties] Properties to set + * @returns Segment instance + */ + public static create(properties?: onnx.TensorProto.ISegment): onnx.TensorProto.Segment; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} + * messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link + * onnx.TensorProto.Segment.verify|verify} messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto.Segment; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto.Segment; + + /** + * Verifies a Segment message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Segment + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto.Segment; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @param message Segment + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto.Segment, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Segment to JSON. 
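`TensorProto` stores payloads either in a typed field such as `floatData` or packed little-endian into `rawData`, with `dataType` selecting the interpretation. A helper sketch (the function name is an invention for illustration) covering both cases for `FLOAT` tensors:

```ts
import {onnx} from './onnx';  // the generated module from this directory

// Extract FLOAT tensor data, whether serialized as float_data or raw_data.
function floatDataOf(t: onnx.ITensorProto): Float32Array {
  if (t.dataType !== onnx.TensorProto.DataType.FLOAT) {
    throw new Error(`expected FLOAT, got dataType=${t.dataType}`);
  }
  if (t.floatData && t.floatData.length > 0) {
    return Float32Array.from(t.floatData);
  }
  // Copy raw_data into a fresh buffer so the Float32Array view is 4-byte aligned.
  const raw = (t.rawData ?? new Uint8Array(0)).slice();
  return new Float32Array(raw.buffer, 0, raw.byteLength / 4);
}
```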
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Segment + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** DataLocation enum. */ + enum DataLocation { DEFAULT = 0, EXTERNAL = 1 } + } + + /** Properties of a SparseTensorProto. */ + interface ISparseTensorProto { + /** SparseTensorProto values */ + values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices */ + indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims */ + dims?: ((number | Long)[]|null); + } + + /** Represents a SparseTensorProto. */ + class SparseTensorProto implements ISparseTensorProto { + /** + * Constructs a new SparseTensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ISparseTensorProto); + + /** SparseTensorProto values. */ + public values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices. */ + public indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims. */ + public dims: (number|Long)[]; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensorProto instance + */ + public static create(properties?: onnx.ISparseTensorProto): onnx.SparseTensorProto; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} + * messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link + * onnx.SparseTensorProto.verify|verify} messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.SparseTensorProto; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.SparseTensorProto; + + /** + * Verifies a SparseTensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. 
+ * @param object Plain object + * @returns SparseTensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.SparseTensorProto; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @param message SparseTensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.SparseTensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this SparseTensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorShapeProto. */ + interface ITensorShapeProto { + /** TensorShapeProto dim */ + dim?: (onnx.TensorShapeProto.IDimension[]|null); + } + + /** Represents a TensorShapeProto. */ + class TensorShapeProto implements ITensorShapeProto { + /** + * Constructs a new TensorShapeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorShapeProto); + + /** TensorShapeProto dim. */ + public dim: onnx.TensorShapeProto.IDimension[]; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorShapeProto instance + */ + public static create(properties?: onnx.ITensorShapeProto): onnx.TensorShapeProto; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} + * messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.verify|verify} messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto; + + /** + * Verifies a TensorShapeProto message. 
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorShapeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @param message TensorShapeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorShapeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorShapeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorShapeProto { + + /** Properties of a Dimension. */ + interface IDimension { + /** Dimension dimValue */ + dimValue?: (number|Long|null); + + /** Dimension dimParam */ + dimParam?: (string|null); + + /** Dimension denotation */ + denotation?: (string|null); + } + + /** Represents a Dimension. */ + class Dimension implements IDimension { + /** + * Constructs a new Dimension. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorShapeProto.IDimension); + + /** Dimension dimValue. */ + public dimValue?: (number|Long|null); + + /** Dimension dimParam. */ + public dimParam?: (string|null); + + /** Dimension denotation. */ + public denotation: string; + + /** Dimension value. */ + public value?: ('dimValue'|'dimParam'); + + /** + * Creates a new Dimension instance using the specified properties. + * @param [properties] Properties to set + * @returns Dimension instance + */ + public static create(properties?: onnx.TensorShapeProto.IDimension): onnx.TensorShapeProto.Dimension; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): + $protobuf.Writer; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto.Dimension; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto.Dimension; + + /** + * Verifies a Dimension message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Dimension + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto.Dimension; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @param message Dimension + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto.Dimension, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Dimension to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Dimension + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of a TypeProto. */ + interface ITypeProto { + /** TypeProto tensorType */ + tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType */ + sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType */ + mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType */ + optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType */ + sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation */ + denotation?: (string|null); + } + + /** Represents a TypeProto. */ + class TypeProto implements ITypeProto { + /** + * Constructs a new TypeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITypeProto); + + /** TypeProto tensorType. */ + public tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType. */ + public sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType. */ + public mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType. */ + public optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType. */ + public sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation. */ + public denotation: string; + + /** TypeProto value. */ + public value?: ('tensorType'|'sequenceType'|'mapType'|'optionalType'|'sparseTensorType'); + + /** + * Creates a new TypeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TypeProto instance + */ + public static create(properties?: onnx.ITypeProto): onnx.TypeProto; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TypeProto message, length delimited. 
Does not implicitly {@link + * onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto; + + /** + * Verifies a TypeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TypeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @param message TypeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TypeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TypeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TypeProto { + + /** Properties of a Tensor. */ + interface ITensor { + /** Tensor elemType */ + elemType?: (number|null); + + /** Tensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a Tensor. */ + class Tensor implements ITensor { + /** + * Constructs a new Tensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ITensor); + + /** Tensor elemType. */ + public elemType: number; + + /** Tensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new Tensor instance using the specified properties. + * @param [properties] Properties to set + * @returns Tensor instance + */ + public static create(properties?: onnx.TypeProto.ITensor): onnx.TypeProto.Tensor; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Tensor message, length delimited. 
Does not implicitly {@link + * onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Tensor; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Tensor; + + /** + * Verifies a Tensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Tensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Tensor; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @param message Tensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Tensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Tensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Tensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Sequence. */ + interface ISequence { + /** Sequence elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents a Sequence. */ + class Sequence implements ISequence { + /** + * Constructs a new Sequence. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISequence); + + /** Sequence elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Sequence instance using the specified properties. + * @param [properties] Properties to set + * @returns Sequence instance + */ + public static create(properties?: onnx.TypeProto.ISequence): onnx.TypeProto.Sequence; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} + * messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Sequence.verify|verify} messages. 
+ * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Sequence; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Sequence; + + /** + * Verifies a Sequence message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Sequence + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Sequence; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @param message Sequence + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Sequence, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Sequence to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Sequence + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Map. */ + interface IMap { + /** Map keyType */ + keyType?: (number|null); + + /** Map valueType */ + valueType?: (onnx.ITypeProto|null); + } + + /** Represents a Map. */ + class Map implements IMap { + /** + * Constructs a new Map. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IMap); + + /** Map keyType. */ + public keyType: number; + + /** Map valueType. */ + public valueType?: (onnx.ITypeProto|null); + + /** + * Creates a new Map instance using the specified properties. + * @param [properties] Properties to set + * @returns Map instance + */ + public static create(properties?: onnx.TypeProto.IMap): onnx.TypeProto.Map; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Map.verify|verify} messages. 
+ * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Map message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Map; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Map; + + /** + * Verifies a Map message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Map + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Map; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @param message Map + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Map, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this Map to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Map + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of an Optional. */ + interface IOptional { + /** Optional elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents an Optional. */ + class Optional implements IOptional { + /** + * Constructs a new Optional. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IOptional); + + /** Optional elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Optional instance using the specified properties. + * @param [properties] Properties to set + * @returns Optional instance + */ + public static create(properties?: onnx.TypeProto.IOptional): onnx.TypeProto.Optional; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} + * messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Optional.verify|verify} messages. 
+ * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Optional; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Optional; + + /** + * Verifies an Optional message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Optional + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Optional; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @param message Optional + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Optional, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Optional to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Optional + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a SparseTensor. */ + interface ISparseTensor { + /** SparseTensor elemType */ + elemType?: (number|null); + + /** SparseTensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a SparseTensor. */ + class SparseTensor implements ISparseTensor { + /** + * Constructs a new SparseTensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISparseTensor); + + /** SparseTensor elemType. */ + public elemType: number; + + /** SparseTensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new SparseTensor instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensor instance + */ + public static create(properties?: onnx.TypeProto.ISparseTensor): onnx.TypeProto.SparseTensor; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensor message, length delimited. 
Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.SparseTensor; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.SparseTensor; + + /** + * Verifies a SparseTensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.SparseTensor; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @param message SparseTensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.SparseTensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this SparseTensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of an OperatorSetIdProto. */ + interface IOperatorSetIdProto { + /** OperatorSetIdProto domain */ + domain?: (string|null); + + /** OperatorSetIdProto version */ + version?: (number|Long|null); + } + + /** Represents an OperatorSetIdProto. */ + class OperatorSetIdProto implements IOperatorSetIdProto { + /** + * Constructs a new OperatorSetIdProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IOperatorSetIdProto); + + /** OperatorSetIdProto domain. */ + public domain: string; + + /** OperatorSetIdProto version. */ + public version: (number|Long); + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @param [properties] Properties to set + * @returns OperatorSetIdProto instance + */ + public static create(properties?: onnx.IOperatorSetIdProto): onnx.OperatorSetIdProto; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. 
+ * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.OperatorSetIdProto; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.OperatorSetIdProto; + + /** + * Verifies an OperatorSetIdProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns OperatorSetIdProto + */ + public static fromObject(object: {[k: string]: any}): onnx.OperatorSetIdProto; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @param message OperatorSetIdProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.OperatorSetIdProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this OperatorSetIdProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for OperatorSetIdProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** OperatorStatus enum. */ + enum OperatorStatus { EXPERIMENTAL = 0, STABLE = 1 } + + /** Properties of a FunctionProto. 
*/ + interface IFunctionProto { + /** FunctionProto name */ + name?: (string|null); + + /** FunctionProto input */ + input?: (string[]|null); + + /** FunctionProto output */ + output?: (string[]|null); + + /** FunctionProto attribute */ + attribute?: (string[]|null); + + /** FunctionProto attributeProto */ + attributeProto?: (onnx.IAttributeProto[]|null); + + /** FunctionProto node */ + node?: (onnx.INodeProto[]|null); + + /** FunctionProto docString */ + docString?: (string|null); + + /** FunctionProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** FunctionProto domain */ + domain?: (string|null); + } + + /** Represents a FunctionProto. */ + class FunctionProto implements IFunctionProto { + /** + * Constructs a new FunctionProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IFunctionProto); + + /** FunctionProto name. */ + public name: string; + + /** FunctionProto input. */ + public input: string[]; + + /** FunctionProto output. */ + public output: string[]; + + /** FunctionProto attribute. */ + public attribute: string[]; + + /** FunctionProto attributeProto. */ + public attributeProto: onnx.IAttributeProto[]; + + /** FunctionProto node. */ + public node: onnx.INodeProto[]; + + /** FunctionProto docString. */ + public docString: string; + + /** FunctionProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** FunctionProto domain. */ + public domain: string; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @param [properties] Properties to set + * @returns FunctionProto instance + */ + public static create(properties?: onnx.IFunctionProto): onnx.FunctionProto; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} + * messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link + * onnx.FunctionProto.verify|verify} messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.FunctionProto; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.FunctionProto; + + /** + * Verifies a FunctionProto message. 
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns FunctionProto + */ + public static fromObject(object: {[k: string]: any}): onnx.FunctionProto; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @param message FunctionProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.FunctionProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this FunctionProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for FunctionProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } +} diff --git a/js/node/test/ort-schema/protobuf/onnx.js b/js/node/test/ort-schema/protobuf/onnx.js new file mode 100644 index 0000000000000..681855132d4e8 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.js @@ -0,0 +1,7658 @@ +/*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ +"use strict"; + +var $protobuf = require("protobufjs/minimal"); + +// Common aliases +var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; + +// Exported root namespace +var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); + +$root.onnx = (function() { + + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. + * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "_START_VERSION"] = 0; + values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; + values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; + values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; + values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; + values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; + values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; + values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; + values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; + values[valuesById[9] = "IR_VERSION"] = 9; + return values; + })(); + + onnx.AttributeProto = (function() { + + /** + * Properties of an AttributeProto. 
+ * @memberof onnx
+ * @interface IAttributeProto
+ * @property {string|null} [name] AttributeProto name
+ * @property {string|null} [refAttrName] AttributeProto refAttrName
+ * @property {string|null} [docString] AttributeProto docString
+ * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type
+ * @property {number|null} [f] AttributeProto f
+ * @property {number|Long|null} [i] AttributeProto i
+ * @property {Uint8Array|null} [s] AttributeProto s
+ * @property {onnx.ITensorProto|null} [t] AttributeProto t
+ * @property {onnx.IGraphProto|null} [g] AttributeProto g
+ * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor
+ * @property {onnx.ITypeProto|null} [tp] AttributeProto tp
+ * @property {Array.<number>|null} [floats] AttributeProto floats
+ * @property {Array.<number|Long>|null} [ints] AttributeProto ints
+ * @property {Array.<Uint8Array>|null} [strings] AttributeProto strings
+ * @property {Array.<onnx.ITensorProto>|null} [tensors] AttributeProto tensors
+ * @property {Array.<onnx.IGraphProto>|null} [graphs] AttributeProto graphs
+ * @property {Array.<onnx.ISparseTensorProto>|null} [sparseTensors] AttributeProto sparseTensors
+ * @property {Array.<onnx.ITypeProto>|null} [typeProtos] AttributeProto typeProtos
+ */
+
+ /**
+ * Constructs a new AttributeProto.
+ * @memberof onnx
+ * @classdesc Represents an AttributeProto.
+ * @implements IAttributeProto
+ * @constructor
+ * @param {onnx.IAttributeProto=} [properties] Properties to set
+ */
+ function AttributeProto(properties) {
+ this.floats = [];
+ this.ints = [];
+ this.strings = [];
+ this.tensors = [];
+ this.graphs = [];
+ this.sparseTensors = [];
+ this.typeProtos = [];
+ if (properties)
+ for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i)
+ if (properties[keys[i]] != null)
+ this[keys[i]] = properties[keys[i]];
+ }
+
+ /**
+ * AttributeProto name.
+ * @member {string} name
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.name = "";
+
+ /**
+ * AttributeProto refAttrName.
+ * @member {string} refAttrName
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.refAttrName = "";
+
+ /**
+ * AttributeProto docString.
+ * @member {string} docString
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.docString = "";
+
+ /**
+ * AttributeProto type.
+ * @member {onnx.AttributeProto.AttributeType} type
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.type = 0;
+
+ /**
+ * AttributeProto f.
+ * @member {number} f
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.f = 0;
+
+ /**
+ * AttributeProto i.
+ * @member {number|Long} i
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0;
+
+ /**
+ * AttributeProto s.
+ * @member {Uint8Array} s
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.s = $util.newBuffer([]);
+
+ /**
+ * AttributeProto t.
+ * @member {onnx.ITensorProto|null|undefined} t
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.t = null;
+
+ /**
+ * AttributeProto g.
+ * @member {onnx.IGraphProto|null|undefined} g
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.g = null;
+
+ /**
+ * AttributeProto sparseTensor.
+ * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.sparseTensor = null;
+
+ /**
+ * AttributeProto tp.
+ * @member {onnx.ITypeProto|null|undefined} tp
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.tp = null;
+
+ /**
+ * AttributeProto floats.
+ * @member {Array.<number>} floats
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.floats = $util.emptyArray;
+
+ /**
+ * AttributeProto ints.
+ * @member {Array.<number|Long>} ints
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.ints = $util.emptyArray;
+
+ /**
+ * AttributeProto strings.
+ * @member {Array.<Uint8Array>} strings
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.strings = $util.emptyArray;
+
+ /**
+ * AttributeProto tensors.
+ * @member {Array.<onnx.ITensorProto>} tensors
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.tensors = $util.emptyArray;
+
+ /**
+ * AttributeProto graphs.
+ * @member {Array.<onnx.IGraphProto>} graphs
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.graphs = $util.emptyArray;
+
+ /**
+ * AttributeProto sparseTensors.
+ * @member {Array.<onnx.ISparseTensorProto>} sparseTensors
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.sparseTensors = $util.emptyArray;
+
+ /**
+ * AttributeProto typeProtos.
+ * @member {Array.<onnx.ITypeProto>} typeProtos
+ * @memberof onnx.AttributeProto
+ * @instance
+ */
+ AttributeProto.prototype.typeProtos = $util.emptyArray;
+
+ /**
+ * Creates a new AttributeProto instance using the specified properties.
+ * @function create
+ * @memberof onnx.AttributeProto
+ * @static
+ * @param {onnx.IAttributeProto=} [properties] Properties to set
+ * @returns {onnx.AttributeProto} AttributeProto instance
+ */
+ AttributeProto.create = function create(properties) {
+ return new AttributeProto(properties);
+ };
+
+ /**
+ * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages.
+ * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, "f")) + writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, "i")) + writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, "s")) + writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, "t")) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, "g")) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.floats.length; ++i) + writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/66).fork(); + for (var i = 0; i < message.ints.length; ++i) + writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) + writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) + $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], 
writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) + message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floats.push(reader.float()); + } else + message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) + message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.ints.push(reader.int64()); + } else + message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) + message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) + message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) + message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) + message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if 
(!(message.typeProtos && message.typeProtos.length))
+ message.typeProtos = [];
+ message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32()));
+ break;
+ }
+ default:
+ reader.skipType(tag & 7);
+ break;
+ }
+ }
+ return message;
+ };
+
+ /**
+ * Decodes an AttributeProto message from the specified reader or buffer, length delimited.
+ * @function decodeDelimited
+ * @memberof onnx.AttributeProto
+ * @static
+ * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from
+ * @returns {onnx.AttributeProto} AttributeProto
+ * @throws {Error} If the payload is not a reader or valid buffer
+ * @throws {$protobuf.util.ProtocolError} If required fields are missing
+ */
+ AttributeProto.decodeDelimited = function decodeDelimited(reader) {
+ if (!(reader instanceof $Reader))
+ reader = new $Reader(reader);
+ return this.decode(reader, reader.uint32());
+ };
+
+ /**
+ * Verifies an AttributeProto message.
+ * @function verify
+ * @memberof onnx.AttributeProto
+ * @static
+ * @param {Object.<string,*>} message Plain object to verify
+ * @returns {string|null} `null` if valid, otherwise the reason why it is not
+ */
+ AttributeProto.verify = function verify(message) {
+ if (typeof message !== "object" || message === null)
+ return "object expected";
+ if (message.name != null && message.hasOwnProperty("name"))
+ if (!$util.isString(message.name))
+ return "name: string expected";
+ if (message.refAttrName != null && message.hasOwnProperty("refAttrName"))
+ if (!$util.isString(message.refAttrName))
+ return "refAttrName: string expected";
+ if (message.docString != null && message.hasOwnProperty("docString"))
+ if (!$util.isString(message.docString))
+ return "docString: string expected";
+ if (message.type != null && message.hasOwnProperty("type"))
+ switch (message.type) {
+ default:
+ return "type: enum value expected";
+ case 0:
+ case 1:
+ case 2:
+ case 3:
+ case 4:
+ case 5:
+ case 11:
+ case 13:
+ case 6:
+ case 7:
+ case 8:
+ case 9:
+ case 10:
+ case 12:
+ case 14:
+ break;
+ }
+ if (message.f != null && message.hasOwnProperty("f"))
+ if (typeof message.f !== "number")
+ return "f: number expected";
+ if (message.i != null && message.hasOwnProperty("i"))
+ if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high)))
+ return "i: integer|Long expected";
+ if (message.s != null && message.hasOwnProperty("s"))
+ if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s)))
+ return "s: buffer expected";
+ if (message.t != null && message.hasOwnProperty("t")) {
+ var error = $root.onnx.TensorProto.verify(message.t);
+ if (error)
+ return "t." + error;
+ }
+ if (message.g != null && message.hasOwnProperty("g")) {
+ var error = $root.onnx.GraphProto.verify(message.g);
+ if (error)
+ return "g." + error;
+ }
+ if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) {
+ var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor);
+ if (error)
+ return "sparseTensor." + error;
+ }
+ if (message.tp != null && message.hasOwnProperty("tp")) {
+ var error = $root.onnx.TypeProto.verify(message.tp);
+ if (error)
+ return "tp."
+ error; + } + if (message.floats != null && message.hasOwnProperty("floats")) { + if (!Array.isArray(message.floats)) + return "floats: array expected"; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== "number") + return "floats: number[] expected"; + } + if (message.ints != null && message.hasOwnProperty("ints")) { + if (!Array.isArray(message.ints)) + return "ints: array expected"; + for (var i = 0; i < message.ints.length; ++i) + if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) + return "ints: integer|Long[] expected"; + } + if (message.strings != null && message.hasOwnProperty("strings")) { + if (!Array.isArray(message.strings)) + return "strings: array expected"; + for (var i = 0; i < message.strings.length; ++i) + if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) + return "strings: buffer[] expected"; + } + if (message.tensors != null && message.hasOwnProperty("tensors")) { + if (!Array.isArray(message.tensors)) + return "tensors: array expected"; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) + return "tensors." + error; + } + } + if (message.graphs != null && message.hasOwnProperty("graphs")) { + if (!Array.isArray(message.graphs)) + return "graphs: array expected"; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) + return "graphs." + error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { + if (!Array.isArray(message.sparseTensors)) + return "sparseTensors: array expected"; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) + return "sparseTensors." + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { + if (!Array.isArray(message.typeProtos)) + return "typeProtos: array expected"; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) + return "typeProtos." + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject
+ * @memberof onnx.AttributeProto
+ * @static
+ * @param {Object.<string,*>} object Plain object
+ * @returns {onnx.AttributeProto} AttributeProto
+ */
+ AttributeProto.fromObject = function fromObject(object) {
+ if (object instanceof $root.onnx.AttributeProto)
+ return object;
+ var message = new $root.onnx.AttributeProto();
+ if (object.name != null)
+ message.name = String(object.name);
+ if (object.refAttrName != null)
+ message.refAttrName = String(object.refAttrName);
+ if (object.docString != null)
+ message.docString = String(object.docString);
+ switch (object.type) {
+ default:
+ if (typeof object.type === "number") {
+ message.type = object.type;
+ break;
+ }
+ break;
+ case "UNDEFINED":
+ case 0:
+ message.type = 0;
+ break;
+ case "FLOAT":
+ case 1:
+ message.type = 1;
+ break;
+ case "INT":
+ case 2:
+ message.type = 2;
+ break;
+ case "STRING":
+ case 3:
+ message.type = 3;
+ break;
+ case "TENSOR":
+ case 4:
+ message.type = 4;
+ break;
+ case "GRAPH":
+ case 5:
+ message.type = 5;
+ break;
+ case "SPARSE_TENSOR":
+ case 11:
+ message.type = 11;
+ break;
+ case "TYPE_PROTO":
+ case 13:
+ message.type = 13;
+ break;
+ case "FLOATS":
+ case 6:
+ message.type = 6;
+ break;
+ case "INTS":
+ case 7:
+ message.type = 7;
+ break;
+ case "STRINGS":
+ case 8:
+ message.type = 8;
+ break;
+ case "TENSORS":
+ case 9:
+ message.type = 9;
+ break;
+ case "GRAPHS":
+ case 10:
+ message.type = 10;
+ break;
+ case "SPARSE_TENSORS":
+ case 12:
+ message.type = 12;
+ break;
+ case "TYPE_PROTOS":
+ case 14:
+ message.type = 14;
+ break;
+ }
+ if (object.f != null)
+ message.f = Number(object.f);
+ if (object.i != null)
+ if ($util.Long)
+ (message.i = $util.Long.fromValue(object.i)).unsigned = false;
+ else if (typeof object.i === "string")
+ message.i = parseInt(object.i, 10);
+ else if (typeof object.i === "number")
+ message.i = object.i;
+ else if (typeof object.i === "object")
+ message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber();
+ if (object.s != null)
+ if (typeof object.s === "string")
+ $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0);
+ else if (object.s.length >= 0)
+ message.s = object.s;
+ if (object.t != null) {
+ if (typeof object.t !== "object")
+ throw TypeError(".onnx.AttributeProto.t: object expected");
+ message.t = $root.onnx.TensorProto.fromObject(object.t);
+ }
+ if (object.g != null) {
+ if (typeof object.g !== "object")
+ throw TypeError(".onnx.AttributeProto.g: object expected");
+ message.g = $root.onnx.GraphProto.fromObject(object.g);
+ }
+ if (object.sparseTensor != null) {
+ if (typeof object.sparseTensor !== "object")
+ throw TypeError(".onnx.AttributeProto.sparseTensor: object expected");
+ message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor);
+ }
+ if (object.tp != null) {
+ if (typeof object.tp !== "object")
+ throw TypeError(".onnx.AttributeProto.tp: object expected");
+ message.tp = $root.onnx.TypeProto.fromObject(object.tp);
+ }
+ if (object.floats) {
+ if (!Array.isArray(object.floats))
+ throw TypeError(".onnx.AttributeProto.floats: array expected");
+ message.floats = [];
+ for (var i = 0; i < object.floats.length; ++i)
+ message.floats[i] = Number(object.floats[i]);
+ }
+ if (object.ints) {
+ if (!Array.isArray(object.ints))
+ throw TypeError(".onnx.AttributeProto.ints: array expected");
+ message.ints = [];
+ for (var i = 0; i < object.ints.length; ++i)
+ if ($util.Long)
+ (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false;
+ else if (typeof object.ints[i] === "string")
+ message.ints[i] = parseInt(object.ints[i], 10);
+ else if (typeof object.ints[i] === "number")
+ message.ints[i] = object.ints[i];
+ else if (typeof object.ints[i] === "object")
+ message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber();
+ }
+ if (object.strings) {
+ if (!Array.isArray(object.strings))
+ throw TypeError(".onnx.AttributeProto.strings: array expected");
+ message.strings = [];
+ for (var i = 0; i < object.strings.length; ++i)
+ if (typeof object.strings[i] === "string")
+ $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0);
+ else if (object.strings[i].length >= 0)
+ message.strings[i] = object.strings[i];
+ }
+ if (object.tensors) {
+ if (!Array.isArray(object.tensors))
+ throw TypeError(".onnx.AttributeProto.tensors: array expected");
+ message.tensors = [];
+ for (var i = 0; i < object.tensors.length; ++i) {
+ if (typeof object.tensors[i] !== "object")
+ throw TypeError(".onnx.AttributeProto.tensors: object expected");
+ message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]);
+ }
+ }
+ if (object.graphs) {
+ if (!Array.isArray(object.graphs))
+ throw TypeError(".onnx.AttributeProto.graphs: array expected");
+ message.graphs = [];
+ for (var i = 0; i < object.graphs.length; ++i) {
+ if (typeof object.graphs[i] !== "object")
+ throw TypeError(".onnx.AttributeProto.graphs: object expected");
+ message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]);
+ }
+ }
+ if (object.sparseTensors) {
+ if (!Array.isArray(object.sparseTensors))
+ throw TypeError(".onnx.AttributeProto.sparseTensors: array expected");
+ message.sparseTensors = [];
+ for (var i = 0; i < object.sparseTensors.length; ++i) {
+ if (typeof object.sparseTensors[i] !== "object")
+ throw TypeError(".onnx.AttributeProto.sparseTensors: object expected");
+ message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]);
+ }
+ }
+ if (object.typeProtos) {
+ if (!Array.isArray(object.typeProtos))
+ throw TypeError(".onnx.AttributeProto.typeProtos: array expected");
+ message.typeProtos = [];
+ for (var i = 0; i < object.typeProtos.length; ++i) {
+ if (typeof object.typeProtos[i] !== "object")
+ throw TypeError(".onnx.AttributeProto.typeProtos: object expected");
+ message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]);
+ }
+ }
+ return message;
+ };
+
+ /**
+ * Creates a plain object from an AttributeProto message. Also converts values to other types if specified.
+ * @function toObject
+ * @memberof onnx.AttributeProto
+ * @static
+ * @param {onnx.AttributeProto} message AttributeProto
+ * @param {$protobuf.IConversionOptions} [options] Conversion options
+ * @returns {Object.<string,*>} Plain object
+ */
+ AttributeProto.toObject = function toObject(message, options) {
+ if (!options)
+ options = {};
+ var object = {};
+ if (options.arrays || options.defaults) {
+ object.floats = [];
+ object.ints = [];
+ object.strings = [];
+ object.tensors = [];
+ object.graphs = [];
+ object.typeProtos = [];
+ object.sparseTensors = [];
+ }
+ if (options.defaults) {
+ object.name = "";
+ object.f = 0;
+ if ($util.Long) {
+ var long = new $util.Long(0, 0, false);
+ object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long;
+ } else
+ object.i = options.longs === String ?
"0" : 0; + if (options.bytes === String) + object.s = ""; + else { + object.s = []; + if (options.bytes !== Array) + object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ""; + object.tp = null; + object.type = options.enums === String ? "UNDEFINED" : 0; + object.refAttrName = ""; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.f != null && message.hasOwnProperty("f")) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty("i")) + if (typeof message.i === "number") + object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; + if (message.s != null && message.hasOwnProperty("s")) + object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; + if (message.t != null && message.hasOwnProperty("t")) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty("g")) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === "number") + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty("tp")) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty("type")) + object.type = options.enums === String ? 
$root.onnx.AttributeProto.AttributeType[message.type] === undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.AttributeProto"; + }; + + /** + * AttributeType enum. + * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "INT"] = 2; + values[valuesById[3] = "STRING"] = 3; + values[valuesById[4] = "TENSOR"] = 4; + values[valuesById[5] = "GRAPH"] = 5; + values[valuesById[11] = "SPARSE_TENSOR"] = 11; + values[valuesById[13] = "TYPE_PROTO"] = 13; + values[valuesById[6] = "FLOATS"] = 6; + values[valuesById[7] = "INTS"] = 7; + values[valuesById[8] = "STRINGS"] = 8; + values[valuesById[9] = "TENSORS"] = 9; + values[valuesById[10] = "GRAPHS"] = 10; + values[valuesById[12] = "SPARSE_TENSORS"] = 12; + values[valuesById[14] = "TYPE_PROTOS"] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function() { + + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. + * @memberof onnx + * @classdesc Represents a ValueInfoProto. 
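The generated AttributeProto above pairs fromObject/toObject with the AttributeType enum, so conversion options decide whether enum fields surface as names or as numbers. A minimal round-trip sketch, assuming this generated file is saved as a hypothetical `./onnx-generated.js` module exporting `$root`:

```js
// Sketch only: "./onnx-generated" is a placeholder path for this generated module.
const { onnx } = require("./onnx-generated");

const msg = onnx.AttributeProto.fromObject({ name: "alpha", type: "FLOAT", f: 0.2 });
const bytes = onnx.AttributeProto.encode(msg).finish(); // Uint8Array
const decoded = onnx.AttributeProto.decode(bytes);

// { enums: String } maps the stored 1 back to "FLOAT"; omit it to keep the raw number.
console.log(onnx.AttributeProto.toObject(decoded, { enums: String }).type); // "FLOAT"
```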
+ * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ""; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ""; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. + * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.type != null && message.hasOwnProperty("type")) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) + return "type." + error; + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) + return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) + message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== "object") + throw TypeError(".onnx.ValueInfoProto.type: object expected"); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. 
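Because ValueInfoProto.verify reports problems as a string instead of throwing, and the generated encode never verifies implicitly, a common pattern is to gate conversion behind an explicit check. A small sketch under the same placeholder module path as above:

```js
const { onnx } = require("./onnx-generated"); // placeholder path, as above

const plain = { name: "input_0", docString: "graph input" };
const problem = onnx.ValueInfoProto.verify(plain);
if (problem) throw Error(problem); // verify returns null when the object is valid

const vi = onnx.ValueInfoProto.fromObject(plain);
const back = onnx.ValueInfoProto.decode(onnx.ValueInfoProto.encode(vi).finish());
console.log(back.name); // "input_0"
```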
+ * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.name = ""; + object.type = null; + object.docString = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.type != null && message.hasOwnProperty("type")) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ValueInfoProto"; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function() { + + /** + * Properties of a NodeProto. + * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ""; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ""; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ""; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. 
+ * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ""; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. + * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.opType != null && message.hasOwnProperty("opType")) + if (!$util.isString(message.opType)) + return "opType: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) + return "attribute." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. 
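NodeProto keeps input/output as plain string arrays while attribute holds nested AttributeProto messages, so the fromObject converter that follows recurses through AttributeProto.fromObject for each element. For illustration, under the same module assumption as the earlier sketches:

```js
const { onnx } = require("./onnx-generated"); // placeholder path, as above

const node = onnx.NodeProto.fromObject({
  name: "relu_0",
  opType: "Relu",
  input: ["X"],
  output: ["Y"],
  attribute: [], // AttributeProto plain objects would be converted recursively here
});
console.log(onnx.NodeProto.verify(node)); // null: the instance doubles as a valid plain object
```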
+ * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) + return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.NodeProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.NodeProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.name != null) + message.name = String(object.name); + if (object.opType != null) + message.opType = String(object.opType); + if (object.domain != null) + message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.NodeProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== "object") + throw TypeError(".onnx.NodeProto.attribute: object expected"); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ""; + object.opType = ""; + object.docString = ""; + object.domain = ""; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.opType != null && message.hasOwnProperty("opType")) + object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. 
+ * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.NodeProto"; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function() { + + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. + * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. 
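As the doc comment notes, generated encoders never call verify implicitly, so a malformed binding is only caught if it is checked explicitly before encoding. A hedged sketch of the verify-then-encode pattern (verify and fromObject for this message are defined a little further down):

```js
const { onnx } = require("./onnx-generated"); // placeholder path, as above

const plain = { updateBinding: [{ key: "step_out", value: "step_in" }] };
const problem = onnx.TrainingInfoProto.verify(plain);
if (problem) throw Error(problem); // encode() below would not have complained

const bytes = onnx.TrainingInfoProto
  .encode(onnx.TrainingInfoProto.fromObject(plain))
  .finish();
```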
+ * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) + message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.initialization != null && message.hasOwnProperty("initialization")) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) + return "initialization." + error; + } + if (message.algorithm != null && message.hasOwnProperty("algorithm")) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) + return "algorithm." + error; + } + if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { + if (!Array.isArray(message.initializationBinding)) + return "initializationBinding: array expected"; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) + return "initializationBinding." + error; + } + } + if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { + if (!Array.isArray(message.updateBinding)) + return "updateBinding: array expected"; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) + return "updateBinding." + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. 
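encodeDelimited/decodeDelimited length-prefix each record, so several messages can be concatenated into one buffer and read back in order from a single Reader. A sketch that assumes Node.js Buffer for the concatenation and the protobufjs/minimal runtime this generated code targets:

```js
const { onnx } = require("./onnx-generated"); // placeholder path, as above
const $protobuf = require("protobufjs/minimal");

const frames = ["W", "B"].map((k) =>
  onnx.TrainingInfoProto.encodeDelimited(
    onnx.TrainingInfoProto.fromObject({ updateBinding: [{ key: k, value: k + "_new" }] })
  ).finish()
);
const reader = $protobuf.Reader.create(Buffer.concat(frames));
while (reader.pos < reader.len)
  console.log(onnx.TrainingInfoProto.decodeDelimited(reader).updateBinding[0].key); // "W", then "B"
```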
+ * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) + return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== "object") + throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== "object") + throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty("initialization")) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty("algorithm")) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. 
+ * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TrainingInfoProto"; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function() { + + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. + * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ""; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ""; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ""; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ""; + + /** + * ModelProto graph. 
+ * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. + * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); + if (message.modelVersion != null && Object.hasOwnProperty.call(message, "modelVersion")) + writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* 
id 25, wireType 2 =*/202).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) + message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) + message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) + message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. 
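With ModelProto.decode in place, loading a serialized model is one read plus one decode; int64 fields such as irVersion come back as Long objects when a Long implementation is available, and as plain numbers otherwise. A sketch where "model.onnx" stands in for any ONNX file on disk:

```js
const fs = require("fs");
const { onnx } = require("./onnx-generated"); // placeholder path, as above

const model = onnx.ModelProto.decode(fs.readFileSync("model.onnx"));
console.log("IR version:", model.irVersion.toString()); // works for Long or number
console.log("producer:", model.producerName, model.producerVersion);
// graph.node is GraphProto's repeated NodeProto field (generated elsewhere in this file)
console.log("nodes:", model.graph ? model.graph.node.length : 0);
```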
+ * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) + return "irVersion: integer|Long expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.producerName != null && message.hasOwnProperty("producerName")) + if (!$util.isString(message.producerName)) + return "producerName: string expected"; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + if (!$util.isString(message.producerVersion)) + return "producerVersion: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) + return "modelVersion: integer|Long expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.graph != null && message.hasOwnProperty("graph")) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) + return "graph." + error; + } + if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { + if (!Array.isArray(message.metadataProps)) + return "metadataProps: array expected"; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) + return "metadataProps." + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { + if (!Array.isArray(message.trainingInfo)) + return "trainingInfo: array expected"; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) + return "trainingInfo." + error; + } + } + if (message.functions != null && message.hasOwnProperty("functions")) { + if (!Array.isArray(message.functions)) + return "functions: array expected"; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) + return "functions." + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. 
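The converter that follows accepts irVersion and modelVersion as a string, number, Long, or {low, high} pair, and toObject can normalize them back out through the longs option:

```js
const { onnx } = require("./onnx-generated"); // placeholder path, as above

const m = onnx.ModelProto.fromObject({ irVersion: "8", producerName: "demo" });
const back = onnx.ModelProto.toObject(m, { longs: String, defaults: true });
console.log(back.irVersion);    // "8": int64s surface as decimal strings
console.log(back.producerName); // "demo"
```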
+ * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) + return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) + (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === "string") + message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === "number") + message.irVersion = object.irVersion; + else if (typeof object.irVersion === "object") + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.ModelProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.ModelProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) + message.producerName = String(object.producerName); + if (object.producerVersion != null) + message.producerVersion = String(object.producerVersion); + if (object.domain != null) + message.domain = String(object.domain); + if (object.modelVersion != null) + if ($util.Long) + (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === "string") + message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === "number") + message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === "object") + message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); + if (object.docString != null) + message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== "object") + throw TypeError(".onnx.ModelProto.graph: object expected"); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) + throw TypeError(".onnx.ModelProto.metadataProps: array expected"); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== "object") + throw TypeError(".onnx.ModelProto.metadataProps: object expected"); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) + throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== "object") + throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) + throw TypeError(".onnx.ModelProto.functions: array expected"); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== "object") + throw TypeError(".onnx.ModelProto.functions: object expected"); + message.functions[i] 
= $root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.irVersion = options.longs === String ? "0" : 0; + object.producerName = ""; + object.producerVersion = ""; + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.modelVersion = options.longs === String ? "0" : 0; + object.docString = ""; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (typeof message.irVersion === "number") + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; + if (message.producerName != null && message.hasOwnProperty("producerName")) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (typeof message.modelVersion === "number") + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? 
new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty("graph")) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ModelProto"; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function() { + + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ""; + + /** + * StringStringEntryProto value. + * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ""; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. 
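A note on the conversion helpers above: because irVersion and modelVersion are int64 fields, protobufjs conversion options control how toObject surfaces them in plain objects. A minimal sketch, assuming this generated module is required as `onnx` (the require path and the `modelBytes` buffer are placeholders, not part of this diff):

    var onnx = require("./onnx").onnx;               // hypothetical path to this generated file
    var model = onnx.ModelProto.decode(modelBytes);  // modelBytes: a Uint8Array you already have
    var plain = onnx.ModelProto.toObject(model, {
        longs: String,  // int64 fields (irVersion, modelVersion) become decimal strings
        enums: String,  // enum numbers become enum names
        bytes: String,  // byte fields become base64 strings
        defaults: true  // materialize default scalars and empty arrays
    });
    var roundTripped = onnx.ModelProto.fromObject(plain);  // reverses the conversion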
+ * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, "key")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, "value")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. 
+         * @function decodeDelimited
+         * @memberof onnx.StringStringEntryProto
+         * @static
+         * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from
+         * @returns {onnx.StringStringEntryProto} StringStringEntryProto
+         * @throws {Error} If the payload is not a reader or valid buffer
+         * @throws {$protobuf.util.ProtocolError} If required fields are missing
+         */
+        StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) {
+            if (!(reader instanceof $Reader))
+                reader = new $Reader(reader);
+            return this.decode(reader, reader.uint32());
+        };
+
+        /**
+         * Verifies a StringStringEntryProto message.
+         * @function verify
+         * @memberof onnx.StringStringEntryProto
+         * @static
+         * @param {Object.<string,*>} message Plain object to verify
+         * @returns {string|null} `null` if valid, otherwise the reason why it is not
+         */
+        StringStringEntryProto.verify = function verify(message) {
+            if (typeof message !== "object" || message === null)
+                return "object expected";
+            if (message.key != null && message.hasOwnProperty("key"))
+                if (!$util.isString(message.key))
+                    return "key: string expected";
+            if (message.value != null && message.hasOwnProperty("value"))
+                if (!$util.isString(message.value))
+                    return "value: string expected";
+            return null;
+        };
+
+        /**
+         * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types.
+         * @function fromObject
+         * @memberof onnx.StringStringEntryProto
+         * @static
+         * @param {Object.<string,*>} object Plain object
+         * @returns {onnx.StringStringEntryProto} StringStringEntryProto
+         */
+        StringStringEntryProto.fromObject = function fromObject(object) {
+            if (object instanceof $root.onnx.StringStringEntryProto)
+                return object;
+            var message = new $root.onnx.StringStringEntryProto();
+            if (object.key != null)
+                message.key = String(object.key);
+            if (object.value != null)
+                message.value = String(object.value);
+            return message;
+        };
+
+        /**
+         * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified.
+         * @function toObject
+         * @memberof onnx.StringStringEntryProto
+         * @static
+         * @param {onnx.StringStringEntryProto} message StringStringEntryProto
+         * @param {$protobuf.IConversionOptions} [options] Conversion options
+         * @returns {Object.<string,*>} Plain object
+         */
+        StringStringEntryProto.toObject = function toObject(message, options) {
+            if (!options)
+                options = {};
+            var object = {};
+            if (options.defaults) {
+                object.key = "";
+                object.value = "";
+            }
+            if (message.key != null && message.hasOwnProperty("key"))
+                object.key = message.key;
+            if (message.value != null && message.hasOwnProperty("value"))
+                object.value = message.value;
+            return object;
+        };
+
+        /**
+         * Converts this StringStringEntryProto to JSON.
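For a message this small, the generated lifecycle is easy to see end to end. A sketch under the same `onnx` require assumption as above:

    var entry = onnx.StringStringEntryProto.create({ key: "author", value: "onnxruntime" });
    var buf = onnx.StringStringEntryProto.encode(entry).finish();  // Uint8Array
    var decoded = onnx.StringStringEntryProto.decode(buf);
    // decoded.key === "author" && decoded.value === "onnxruntime"
    var json = decoded.toJSON();  // plain { key, value } object via the toObject defaults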
+ * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.StringStringEntryProto"; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function() { + + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ""; + + /** + * TensorAnnotation quantParameterTensorNames. + * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. 
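encodeDelimited and decodeDelimited differ from their plain counterparts only in framing: each message is prefixed with its varint byte length, so several frames can share one buffer. A sketch, same assumptions as above:

    var protobuf = require("protobufjs/minimal");
    var b1 = onnx.StringStringEntryProto.encodeDelimited({ key: "a", value: "1" }).finish();
    var b2 = onnx.StringStringEntryProto.encodeDelimited({ key: "b", value: "2" }).finish();
    var buf = new Uint8Array(b1.length + b2.length);
    buf.set(b1, 0);
    buf.set(b2, b1.length);
    var reader = protobuf.Reader.create(buf);
    var entries = [];
    while (reader.pos < reader.len)
        entries.push(onnx.StringStringEntryProto.decodeDelimited(reader));
    // entries.length === 2; each decodeDelimited call consumed exactly one frame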
+ * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + if (!$util.isString(message.tensorName)) + return "tensorName: string expected"; + if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { + if (!Array.isArray(message.quantParameterTensorNames)) + return "quantParameterTensorNames: array expected"; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) + return "quantParameterTensorNames." + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. 
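verify never throws; it returns null for valid input, or a dotted path naming the first offending field, composed across nested messages as in the quantParameterTensorNames loop above. The tensor names in this sketch are made up:

    var err = onnx.TensorAnnotation.verify({
        tensorName: "conv1_w_quant",
        quantParameterTensorNames: [{ key: "scale", value: 42 }]  // value must be a string
    });
    // err === "quantParameterTensorNames.value: string expected"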
+ * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) + return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) + message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== "object") + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.quantParameterTensorNames = []; + if (options.defaults) + object.tensorName = ""; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorAnnotation"; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function() { + + /** + * Properties of a GraphProto. 
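fromObject is the lenient counterpart: it coerces JSON-ish input field by field and throws a TypeError only when the structure is unrecoverable (a non-array for a repeated field, a non-object for a nested message). A sketch with made-up tensor names:

    var ann = onnx.TensorAnnotation.fromObject({
        tensorName: "conv1_w_quant",
        quantParameterTensorNames: [{ key: "scale", value: "conv1_w_scale" }]
    });
    var bytes = onnx.TensorAnnotation.encode(ann).finish();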
+ * @memberof onnx + * @interface IGraphProto + * @property {Array.|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.|null} [initializer] GraphProto initializer + * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.|null} [input] GraphProto input + * @property {Array.|null} [output] GraphProto output + * @property {Array.|null} [valueInfo] GraphProto valueInfo + * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ""; + + /** + * GraphProto initializer. + * @member {Array.} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString. + * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ""; + + /** + * GraphProto input. + * @member {Array.} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; + + /** + * GraphProto quantizationAnnotation. + * @member {Array.} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. 
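The numeric literals in the generated writer calls, such as writer.uint32(/* id 2, wireType 2 =*/18), are precomputed protobuf field tags: (fieldNumber << 3) | wireType, where wire type 2 means length-delimited and 0 means varint. A quick check:

    function tag(fieldNumber, wireType) { return (fieldNumber << 3) | wireType; }
    tag(1, 2);   // 10  -> GraphProto.node
    tag(2, 2);   // 18  -> GraphProto.name
    tag(15, 2);  // 122 -> GraphProto.sparseInitializer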
+ * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) + message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) + message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) + message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.initializer != null && message.hasOwnProperty("initializer")) { + if (!Array.isArray(message.initializer)) + return "initializer: array expected"; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) + return "initializer." 
+ error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { + if (!Array.isArray(message.sparseInitializer)) + return "sparseInitializer: array expected"; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) + return "sparseInitializer." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) + return "input." + error; + } + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) + return "output." + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { + if (!Array.isArray(message.valueInfo)) + return "valueInfo: array expected"; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) + return "valueInfo." + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { + if (!Array.isArray(message.quantizationAnnotation)) + return "quantizationAnnotation: array expected"; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) + return "quantizationAnnotation." + error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. 
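One detail of the decode loop above worth calling out: any unmatched field number falls through to reader.skipType(tag & 7), so buffers written by a newer schema revision still decode, with unknown fields silently dropped. A sketch of that mechanic, where field 99 is a hypothetical future addition:

    var protobuf = require("protobufjs/minimal");
    var w = protobuf.Writer.create();
    w.uint32((2 << 3) | 2).string("my_graph");  // GraphProto.name, a known field
    w.uint32((99 << 3) | 0).int32(7);           // unknown varint field 99
    var g = onnx.GraphProto.decode(w.finish());
    // g.name === "my_graph"; field 99 was skipped, not an error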
+ * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) + return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.GraphProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.GraphProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) + message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) + throw TypeError(".onnx.GraphProto.initializer: array expected"); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== "object") + throw TypeError(".onnx.GraphProto.initializer: object expected"); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== "object") + throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.GraphProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== "object") + throw TypeError(".onnx.GraphProto.input: object expected"); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.GraphProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== "object") + throw TypeError(".onnx.GraphProto.output: object expected"); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) + throw TypeError(".onnx.GraphProto.valueInfo: array expected"); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== "object") + throw TypeError(".onnx.GraphProto.valueInfo: object expected"); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== "object") + throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto 
message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.GraphProto"; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function() { + + /** + * Properties of a TensorProto. 
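Putting the GraphProto helpers together, a structurally minimal graph can be assembled from a plain object and serialized. This is only a shape sketch, not a valid ONNX model, and it assumes NodeProto and ValueInfoProto expose the usual camelCased fields (opType, input, output, name):

    var graph = onnx.GraphProto.fromObject({
        name: "tiny",
        node: [{ opType: "Identity", input: ["X"], output: ["Y"] }],
        input: [{ name: "X" }],
        output: [{ name: "Y" }]
    });
    var bytes = onnx.GraphProto.encode(graph).finish();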
+ * @memberof onnx + * @interface ITensorProto + * @property {Array.|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.|null} [floatData] TensorProto floatData + * @property {Array.|null} [int32Data] TensorProto int32Data + * @property {Array.|null} [stringData] TensorProto stringData + * @property {Array.|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.|null} [doubleData] TensorProto doubleData + * @property {Array.|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment. + * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ""; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ""; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. 
+ * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data. + * @member {Array.} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/10).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) + writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) + $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/34).fork(); + for (var i = 0; i < message.floatData.length; ++i) + writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) + writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if (message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) + writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) + writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) + writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != null && 
Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) + writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) + message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floatData.push(reader.float()); + } else + message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) + message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int32Data.push(reader.int32()); + } else + message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) + message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) + message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int64Data.push(reader.int64()); + } else + message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } 
+ case 13: { + if (!(message.externalData && message.externalData.length)) + message.externalData = []; + message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) + message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.doubleData.push(reader.double()); + } else + message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) + message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.uint64Data.push(reader.uint64()); + } else + message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. + * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + if (!$util.isInteger(message.dataType)) + return "dataType: integer expected"; + if (message.segment != null && message.hasOwnProperty("segment")) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) + return "segment." 
+ error; + } + if (message.floatData != null && message.hasOwnProperty("floatData")) { + if (!Array.isArray(message.floatData)) + return "floatData: array expected"; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== "number") + return "floatData: number[] expected"; + } + if (message.int32Data != null && message.hasOwnProperty("int32Data")) { + if (!Array.isArray(message.int32Data)) + return "int32Data: array expected"; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) + return "int32Data: integer[] expected"; + } + if (message.stringData != null && message.hasOwnProperty("stringData")) { + if (!Array.isArray(message.stringData)) + return "stringData: array expected"; + for (var i = 0; i < message.stringData.length; ++i) + if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) + return "stringData: buffer[] expected"; + } + if (message.int64Data != null && message.hasOwnProperty("int64Data")) { + if (!Array.isArray(message.int64Data)) + return "int64Data: array expected"; + for (var i = 0; i < message.int64Data.length; ++i) + if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) + return "int64Data: integer|Long[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.rawData != null && message.hasOwnProperty("rawData")) + if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) + return "rawData: buffer expected"; + if (message.externalData != null && message.hasOwnProperty("externalData")) { + if (!Array.isArray(message.externalData)) + return "externalData: array expected"; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) + return "externalData." + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + switch (message.dataLocation) { + default: + return "dataLocation: enum value expected"; + case 0: + case 1: + break; + } + if (message.doubleData != null && message.hasOwnProperty("doubleData")) { + if (!Array.isArray(message.doubleData)) + return "doubleData: array expected"; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== "number") + return "doubleData: number[] expected"; + } + if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { + if (!Array.isArray(message.uint64Data)) + return "uint64Data: array expected"; + for (var i = 0; i < message.uint64Data.length; ++i) + if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) + return "uint64Data: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. 
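The repeated scalar fields of TensorProto (dims, floatData, int64Data, and so on) are written packed, one length-delimited blob per field, and the (tag & 7) === 2 branches in decode accept both packed and unpacked forms. A round-trip sketch, same require assumption as above:

    var t = onnx.TensorProto.fromObject({
        name: "w",
        dataType: 1,             // 1 = FLOAT in onnx.TensorProto.DataType
        dims: [2, 2],            // packed int64, field 1
        floatData: [1, 2, 3, 4]  // packed float, field 4
    });
    var buf = onnx.TensorProto.encode(t).finish();
    var back = onnx.TensorProto.decode(buf);
    // back.floatData.length === 4; dims come back as Long values when long.js is installed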
+ * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto} TensorProto + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) + return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.TensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) + message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== "object") + throw TypeError(".onnx.TensorProto.segment: object expected"); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) + throw TypeError(".onnx.TensorProto.floatData: array expected"); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) + message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) + throw TypeError(".onnx.TensorProto.int32Data: array expected"); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) + message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) + throw TypeError(".onnx.TensorProto.stringData: array expected"); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof object.stringData[i] === "string") + $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); + else if (object.stringData[i].length >= 0) + message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) + throw TypeError(".onnx.TensorProto.int64Data: array expected"); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) + (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === "string") + message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === "number") + message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === "object") + message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); + } + if (object.name != null) + message.name = String(object.name); + if (object.docString != null) + message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === "string") + $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); + else if (object.rawData.length >= 0) + message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) + throw TypeError(".onnx.TensorProto.externalData: array expected"); + message.externalData = []; + for (var i = 0; 
i < object.externalData.length; ++i) { + if (typeof object.externalData[i] !== "object") + throw TypeError(".onnx.TensorProto.externalData: object expected"); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: + if (typeof object.dataLocation === "number") { + message.dataLocation = object.dataLocation; + break; + } + break; + case "DEFAULT": + case 0: + message.dataLocation = 0; + break; + case "EXTERNAL": + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) + throw TypeError(".onnx.TensorProto.doubleData: array expected"); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) + message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) + throw TypeError(".onnx.TensorProto.uint64Data: array expected"); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) + (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === "string") + message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === "number") + message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === "object") + message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ""; + if (options.bytes === String) + object.rawData = ""; + else { + object.rawData = []; + if (options.bytes !== Array) + object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ""; + object.dataLocation = options.enums === String ? "DEFAULT" : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? 
new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty("segment")) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) + object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) + object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === "number") + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.rawData != null && message.hasOwnProperty("rawData")) + object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === "number") + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? 
message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. + * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto"; + }; + + /** + * DataType enum. + * @name onnx.TensorProto.DataType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "UINT8"] = 2; + values[valuesById[3] = "INT8"] = 3; + values[valuesById[4] = "UINT16"] = 4; + values[valuesById[5] = "INT16"] = 5; + values[valuesById[6] = "INT32"] = 6; + values[valuesById[7] = "INT64"] = 7; + values[valuesById[8] = "STRING"] = 8; + values[valuesById[9] = "BOOL"] = 9; + values[valuesById[10] = "FLOAT16"] = 10; + values[valuesById[11] = "DOUBLE"] = 11; + values[valuesById[12] = "UINT32"] = 12; + values[valuesById[13] = "UINT64"] = 13; + values[valuesById[14] = "COMPLEX64"] = 14; + values[valuesById[15] = "COMPLEX128"] = 15; + values[valuesById[16] = "BFLOAT16"] = 16; + values[valuesById[17] = "FLOAT8E4M3FN"] = 17; + values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; + values[valuesById[19] = "FLOAT8E5M2"] = 19; + values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; + return values; + })(); + + TensorProto.Segment = (function() { + + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. 
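+ * A Segment identifies the slice of a large tensor that the enclosing
+ * TensorProto stores (onnx.proto defines it for tensors stored in chunks).
+ * A minimal construction sketch, with illustrative bounds:
+ *
+ * @example
+ * var seg = onnx.TensorProto.Segment.create({ begin: 0, end: 1024 });
+ * // begin and end accept numbers, decimal strings, or Long-like objects.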
+ * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new Segment instance using the specified properties. + * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, "end")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. 
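+ * A hedged round-trip sketch; the length prefix written by encodeDelimited
+ * is what this method consumes:
+ *
+ * @example
+ * var buf = onnx.TensorProto.Segment.encodeDelimited({ begin: 0, end: 4 }).finish();
+ * var seg = onnx.TensorProto.Segment.decodeDelimited(buf);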
+ * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. + * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.begin != null && message.hasOwnProperty("begin")) + if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) + return "begin: integer|Long expected"; + if (message.end != null && message.hasOwnProperty("end")) + if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) + return "end: integer|Long expected"; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) + return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) + (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === "string") + message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === "number") + message.begin = object.begin; + else if (typeof object.begin === "object") + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) + (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === "string") + message.end = parseInt(object.end, 10); + else if (typeof object.end === "number") + message.end = object.end; + else if (typeof object.end === "object") + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.begin = options.longs === String ? "0" : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? 
long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.end = options.longs === String ? "0" : 0; + } + if (message.begin != null && message.hasOwnProperty("begin")) + if (typeof message.begin === "number") + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; + if (message.end != null && message.hasOwnProperty("end")) + if (typeof message.end === "number") + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. + * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto.Segment"; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "DEFAULT"] = 0; + values[valuesById[1] = "EXTERNAL"] = 1; + return values; + })(); + + return TensorProto; + })(); + + onnx.SparseTensorProto = (function() { + + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. + * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. 
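+ * Shape of the underlying dense tensor, as repeated int64 values. Entries
+ * surface as Long objects when Long support is available; the `longs`
+ * conversion option of toObject turns them into String or Number instead.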
+ * @member {Array.} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, "values")) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/26).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensorProto message. + * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.values != null && message.hasOwnProperty("values")) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) + return "values." + error; + } + if (message.indices != null && message.hasOwnProperty("indices")) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) + return "indices." + error; + } + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. 
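+ * A hedged sketch with illustrative data (dataType 1 is FLOAT and 7 is
+ * INT64 in the onnx.TensorProto.DataType enum above):
+ *
+ * @example
+ * var sparse = onnx.SparseTensorProto.fromObject({
+ *   values: { dims: [2], dataType: 1, floatData: [3.5, 7.25] },
+ *   indices: { dims: [2], dataType: 7, int64Data: [1, 5] },
+ *   dims: [8]
+ * });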
+ * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) + return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== "object") + throw TypeError(".onnx.SparseTensorProto.values: object expected"); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== "object") + throw TypeError(".onnx.SparseTensorProto.indices: object expected"); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.SparseTensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty("values")) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty("indices")) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. 
+ * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.SparseTensorProto"; + }; + + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function() { + + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. + * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. 
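+ * A minimal round-trip sketch (the dimensions are illustrative):
+ *
+ * @example
+ * var buf = onnx.TensorShapeProto.encode({ dim: [{ dimValue: 3 }, { dimParam: "N" }] }).finish();
+ * var shape = onnx.TensorShapeProto.decode(buf); // shape.dim.length === 2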
+ * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) + message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorShapeProto message. + * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dim != null && message.hasOwnProperty("dim")) { + if (!Array.isArray(message.dim)) + return "dim: array expected"; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) + return "dim." + error; + } + } + return null; + }; + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) + return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) + throw TypeError(".onnx.TensorShapeProto.dim: array expected"); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== "object") + throw TypeError(".onnx.TensorShapeProto.dim: object expected"); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto"; + }; + + TensorShapeProto.Dimension = (function() { + + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. + * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. + * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new Dimension instance using the specified properties. 
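+ * dimValue and dimParam belong to the `value` oneof declared above: the
+ * virtual property reports which member is set, and verify() rejects
+ * messages that populate both. An illustrative sketch:
+ *
+ * @example
+ * var dim = onnx.TensorShapeProto.Dimension.create({ dimParam: "batch" });
+ * // dim.value === "dimParam"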
+ * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. + * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + properties.value = 1; + if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) + return "dimValue: integer|Long expected"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + if (!$util.isString(message.dimParam)) + return "dimParam: string expected"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) + return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) + (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === "string") + message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === "number") + message.dimValue = object.dimValue; + else if (typeof object.dimValue === "object") + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) + message.dimParam = String(object.dimParam); + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + if (typeof message.dimValue === "number") + object.dimValue = options.longs === String ? 
String(message.dimValue) : message.dimValue; + else + object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; + if (options.oneofs) + object.value = "dimValue"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + object.dimParam = message.dimParam; + if (options.oneofs) + object.value = "dimParam"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; + }; + + return Dimension; + })(); + + return TensorShapeProto; + })(); + + onnx.TypeProto = (function() { + + /** + * Properties of a TypeProto. + * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. + * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. 
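+ * An optional, human-readable semantic annotation for the whole type, as
+ * described by the type denotation section of onnx.proto.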
+ * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. + * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) + $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) + $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); + if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, "sparseTensorType")) + $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) + $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. 
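+ * A hedged round-trip sketch (elemType 1 corresponds to DataType.FLOAT):
+ *
+ * @example
+ * var buf = onnx.TypeProto.encode({ tensorType: { elemType: 1 } }).finish();
+ * var t = onnx.TypeProto.decode(buf);
+ * // t.value === "tensorType"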
+ * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. + * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) + return "tensorType." + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) + return "sequenceType." + error; + } + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) + return "mapType." + error; + } + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) + return "optionalType." 
+ error;
+ }
+ }
+ if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) {
+ if (properties.value === 1)
+ return "value: multiple values";
+ properties.value = 1;
+ {
+ var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType);
+ if (error)
+ return "sparseTensorType." + error;
+ }
+ }
+ if (message.denotation != null && message.hasOwnProperty("denotation"))
+ if (!$util.isString(message.denotation))
+ return "denotation: string expected";
+ return null;
+ };
+
+ /**
+ * Creates a TypeProto message from a plain object. Also converts values to their respective internal types.
+ * @function fromObject
+ * @memberof onnx.TypeProto
+ * @static
+ * @param {Object.<string,*>} object Plain object
+ * @returns {onnx.TypeProto} TypeProto
+ */
+ TypeProto.fromObject = function fromObject(object) {
+ if (object instanceof $root.onnx.TypeProto)
+ return object;
+ var message = new $root.onnx.TypeProto();
+ if (object.tensorType != null) {
+ if (typeof object.tensorType !== "object")
+ throw TypeError(".onnx.TypeProto.tensorType: object expected");
+ message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType);
+ }
+ if (object.sequenceType != null) {
+ if (typeof object.sequenceType !== "object")
+ throw TypeError(".onnx.TypeProto.sequenceType: object expected");
+ message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType);
+ }
+ if (object.mapType != null) {
+ if (typeof object.mapType !== "object")
+ throw TypeError(".onnx.TypeProto.mapType: object expected");
+ message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType);
+ }
+ if (object.optionalType != null) {
+ if (typeof object.optionalType !== "object")
+ throw TypeError(".onnx.TypeProto.optionalType: object expected");
+ message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType);
+ }
+ if (object.sparseTensorType != null) {
+ if (typeof object.sparseTensorType !== "object")
+ throw TypeError(".onnx.TypeProto.sparseTensorType: object expected");
+ message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType);
+ }
+ if (object.denotation != null)
+ message.denotation = String(object.denotation);
+ return message;
+ };
+
+ /**
+ * Creates a plain object from a TypeProto message. Also converts values to other types if specified.
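+ * A sketch of the oneof-related conversion option (data illustrative):
+ *
+ * @example
+ * var t = onnx.TypeProto.fromObject({ tensorType: { elemType: 1 } });
+ * var obj = onnx.TypeProto.toObject(t, { oneofs: true });
+ * // obj.value === "tensorType"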
+ * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) + object.value = "tensorType"; + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) + object.value = "sequenceType"; + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) + object.value = "mapType"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) + object.value = "sparseTensorType"; + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) + object.value = "optionalType"; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. + * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto"; + }; + + TypeProto.Tensor = (function() { + + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. 
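+ *
+ * Minimal sketch (the element-type value is an assumption; 1 is FLOAT in the ONNX TensorProto.DataType enum):
+ *   var t = onnx.TypeProto.Tensor.create({ elemType: 1 });
+ *   // equivalent to: new onnx.TypeProto.Tensor({ elemType: 1 })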
+ * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. 
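+ *
+ * Contract sketch: returns null for a valid message and a reason string otherwise, e.g.
+ *   onnx.TypeProto.Tensor.verify({ elemType: 1 });   // -> null
+ *   onnx.TypeProto.Tensor.verify({ elemType: "f" }); // -> "elemType: integer expected"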
+ * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) + return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Tensor"; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function() { + + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. 
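+ * A Sequence carries a single `elemType` field that is itself a TypeProto, so nested types (e.g. a sequence of tensors) compose by recursion.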
+ * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. 
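+ *
+ * Mechanics (see the body below): reads a leading uint32 length and delegates to decode() with that length, so it accepts buffers produced by encodeDelimited():
+ *   var seq = onnx.TypeProto.Sequence.decodeDelimited(buf); // buf: a Uint8Array (assumed)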
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) + return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Sequence"; + }; + + return Sequence; + })(); + + TypeProto.Map = (function() { + + /** + * Properties of a Map. 
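+ * A Map pairs an integer `keyType` (a TensorProto.DataType code, e.g. 8 for STRING) with a `valueType` that is itself a TypeProto.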
+ * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. + * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.keyType != null && message.hasOwnProperty("keyType")) + if (!$util.isInteger(message.keyType)) + return "keyType: integer expected"; + if (message.valueType != null && message.hasOwnProperty("valueType")) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) + return "valueType." + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) + return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) + message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== "object") + throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty("keyType")) + object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty("valueType")) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. 
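+ *
+ * Delegation note: a thin wrapper over toObject() with protobuf.js's shared toJSONOptions, so for any Map instance `m` (name assumed for illustration), JSON.stringify(m) and onnx.TypeProto.Map.toObject(m, $protobuf.util.toJSONOptions) agree.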
+ * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Map"; + }; + + return Map; + })(); + + TypeProto.Optional = (function() { + + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. 
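+ *
+ * Minimal sketch (assuming `buf` is a Uint8Array holding a serialized Optional):
+ *   var opt = onnx.TypeProto.Optional.decode(buf);
+ *   // opt.elemType is an onnx.TypeProto when field 1 was present, otherwise null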
+ * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) + return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. 
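+ *
+ * Round-trip sketch (informal): for a valid Optional `m`, onnx.TypeProto.Optional.fromObject(onnx.TypeProto.Optional.toObject(m, {})) reconstructs an equivalent message, since both directions map the single elemType field.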
+ * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Optional"; + }; + + return Optional; + })(); + + TypeProto.SparseTensor = (function() { + + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. + * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. 
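+ *
+ * Wire sketch (informal; the elemType value is illustrative): field 1 is written as a varint with tag byte 0x08 and field 2 as a length-delimited submessage with tag byte 0x12, so
+ *   onnx.TypeProto.SparseTensor.encode({ elemType: 1 }).finish(); // -> Uint8Array [0x08, 0x01]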
+ * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. 
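+ *
+ * Error-contract sketch mirroring the body below: verify(null) returns "object expected", a fractional elemType returns "elemType: integer expected", and a bad nested shape is reported with a "shape." prefix.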
+ * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) + return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; + }; + + return SparseTensor; + })(); + + return TypeProto; + })(); + + onnx.OperatorSetIdProto = (function() { + + /** + * Properties of an OperatorSetIdProto. + * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ + + /** + * Constructs a new OperatorSetIdProto. 
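+ * Each entry names an operator-set dependency: a `domain` string (where "" denotes the default ONNX domain) and an int64 opset `version`, e.g. (illustrative) onnx.OperatorSetIdProto.create({ domain: "", version: 17 }).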
+ * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. + * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ""; + + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, "version")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.version != null && message.hasOwnProperty("version")) + if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) + return "version: integer|Long expected"; + return null; + }; + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) + return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) + message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) + (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === "string") + message.version = parseInt(object.version, 10); + else if (typeof object.version === "number") + message.version = object.version; + else if (typeof object.version === "object") + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = options.longs === String ? 
long.toString() : options.longs === Number ? long.toNumber() : long;
+ } else
+ object.version = options.longs === String ? "0" : 0;
+ }
+ if (message.domain != null && message.hasOwnProperty("domain"))
+ object.domain = message.domain;
+ if (message.version != null && message.hasOwnProperty("version"))
+ if (typeof message.version === "number")
+ object.version = options.longs === String ? String(message.version) : message.version;
+ else
+ object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version;
+ return object;
+ };
+
+ /**
+ * Converts this OperatorSetIdProto to JSON.
+ * @function toJSON
+ * @memberof onnx.OperatorSetIdProto
+ * @instance
+ * @returns {Object.<string,*>} JSON object
+ */
+ OperatorSetIdProto.prototype.toJSON = function toJSON() {
+ return this.constructor.toObject(this, $protobuf.util.toJSONOptions);
+ };
+
+ /**
+ * Gets the default type url for OperatorSetIdProto
+ * @function getTypeUrl
+ * @memberof onnx.OperatorSetIdProto
+ * @static
+ * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com")
+ * @returns {string} The default type url
+ */
+ OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) {
+ if (typeUrlPrefix === undefined) {
+ typeUrlPrefix = "type.googleapis.com";
+ }
+ return typeUrlPrefix + "/onnx.OperatorSetIdProto";
+ };
+
+ return OperatorSetIdProto;
+ })();
+
+ /**
+ * OperatorStatus enum.
+ * @name onnx.OperatorStatus
+ * @enum {number}
+ * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value
+ * @property {number} STABLE=1 STABLE value
+ */
+ onnx.OperatorStatus = (function() {
+ var valuesById = {}, values = Object.create(valuesById);
+ values[valuesById[0] = "EXPERIMENTAL"] = 0;
+ values[valuesById[1] = "STABLE"] = 1;
+ return values;
+ })();
+
+ onnx.FunctionProto = (function() {
+
+ /**
+ * Properties of a FunctionProto.
+ * @memberof onnx
+ * @interface IFunctionProto
+ * @property {string|null} [name] FunctionProto name
+ * @property {Array.<string>|null} [input] FunctionProto input
+ * @property {Array.<string>|null} [output] FunctionProto output
+ * @property {Array.<string>|null} [attribute] FunctionProto attribute
+ * @property {Array.<onnx.IAttributeProto>|null} [attributeProto] FunctionProto attributeProto
+ * @property {Array.<onnx.INodeProto>|null} [node] FunctionProto node
+ * @property {string|null} [docString] FunctionProto docString
+ * @property {Array.<onnx.IOperatorSetIdProto>|null} [opsetImport] FunctionProto opsetImport
+ * @property {string|null} [domain] FunctionProto domain
+ */
+
+ /**
+ * Constructs a new FunctionProto.
+ * @memberof onnx
+ * @classdesc Represents a FunctionProto.
+ * @implements IFunctionProto
+ * @constructor
+ * @param {onnx.IFunctionProto=} [properties] Properties to set
+ */
+ function FunctionProto(properties) {
+ this.input = [];
+ this.output = [];
+ this.attribute = [];
+ this.attributeProto = [];
+ this.node = [];
+ this.opsetImport = [];
+ if (properties)
+ for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i)
+ if (properties[keys[i]] != null)
+ this[keys[i]] = properties[keys[i]];
+ }
+
+ /**
+ * FunctionProto name.
+ * @member {string} name
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.name = "";
+
+ /**
+ * FunctionProto input.
+ * @member {Array.<string>} input
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.input = $util.emptyArray;
+
+ /**
+ * FunctionProto output.
+ * @member {Array.<string>} output
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.output = $util.emptyArray;
+
+ /**
+ * FunctionProto attribute.
+ * @member {Array.<string>} attribute
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.attribute = $util.emptyArray;
+
+ /**
+ * FunctionProto attributeProto.
+ * @member {Array.<onnx.IAttributeProto>} attributeProto
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.attributeProto = $util.emptyArray;
+
+ /**
+ * FunctionProto node.
+ * @member {Array.<onnx.INodeProto>} node
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.node = $util.emptyArray;
+
+ /**
+ * FunctionProto docString.
+ * @member {string} docString
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.docString = "";
+
+ /**
+ * FunctionProto opsetImport.
+ * @member {Array.<onnx.IOperatorSetIdProto>} opsetImport
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.opsetImport = $util.emptyArray;
+
+ /**
+ * FunctionProto domain.
+ * @member {string} domain
+ * @memberof onnx.FunctionProto
+ * @instance
+ */
+ FunctionProto.prototype.domain = "";
+
+ /**
+ * Creates a new FunctionProto instance using the specified properties.
+ * @function create
+ * @memberof onnx.FunctionProto
+ * @static
+ * @param {onnx.IFunctionProto=} [properties] Properties to set
+ * @returns {onnx.FunctionProto} FunctionProto instance
+ */
+ FunctionProto.create = function create(properties) {
+ return new FunctionProto(properties);
+ };
+
+ /**
+ * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages.
+ * @function encode
+ * @memberof onnx.FunctionProto
+ * @static
+ * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode
+ * @param {$protobuf.Writer} [writer] Writer to encode to
+ * @returns {$protobuf.Writer} Writer
+ */
+ FunctionProto.encode = function encode(message, writer) {
+ if (!writer)
+ writer = $Writer.create();
+ if (message.name != null && Object.hasOwnProperty.call(message, "name"))
+ writer.uint32(/* id 1, wireType 2 =*/10).string(message.name);
+ if (message.input != null && message.input.length)
+ for (var i = 0; i < message.input.length; ++i)
+ writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]);
+ if (message.output != null && message.output.length)
+ for (var i = 0; i < message.output.length; ++i)
+ writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]);
+ if (message.attribute != null && message.attribute.length)
+ for (var i = 0; i < message.attribute.length; ++i)
+ writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]);
+ if (message.node != null && message.node.length)
+ for (var i = 0; i < message.node.length; ++i)
+ $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim();
+ if (message.docString != null && Object.hasOwnProperty.call(message, "docString"))
+ writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString);
+ if (message.opsetImport != null && message.opsetImport.length)
+ for (var i = 0; i < message.opsetImport.length; ++i)
+ $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim();
+ if (message.domain != null && Object.hasOwnProperty.call(message, "domain"))
+ writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain);
+ if (message.attributeProto != null && message.attributeProto.length)
+ for (var i = 0; i <
message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) + message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a FunctionProto message. 
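+ *
+ * Sketch of the reported paths (matching the checks below): element failures in repeated string fields read like "input: string[] expected", while nested message failures are prefixed, e.g. "node." + the error from onnx.NodeProto.verify.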
+ * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) + return "attribute: string[] expected"; + } + if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { + if (!Array.isArray(message.attributeProto)) + return "attributeProto: array expected"; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) + return "attributeProto." + error; + } + } + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
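+ *
+ * Conversion sketch (field values assumed for illustration): scalars are coerced with String() and array fields are rebuilt element-wise, e.g.
+ *   onnx.FunctionProto.fromObject({ name: "MyFn", input: ["x"], node: [{ opType: "Relu" }] });
+ * a TypeError is thrown when a field that must be an array is not one.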
+ * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) + return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) + message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.FunctionProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.FunctionProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.FunctionProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) + message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== "object") + throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.FunctionProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.FunctionProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) + message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + object.domain = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = message.attribute[j]; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. 
+ * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.FunctionProto"; + }; + + return FunctionProto; + })(); + + return onnx; +})(); + +module.exports = $root; diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index 968e8a1881810..3eef90356a335 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -4,10 +4,11 @@ import assert from 'assert'; import * as fs from 'fs-extra'; import {jsonc} from 'jsonc'; -import * as onnx_proto from 'onnx-proto'; import {InferenceSession, Tensor} from 'onnxruntime-common'; import * as path from 'path'; +import * as onnx_proto from './ort-schema/protobuf/onnx'; + export const TEST_ROOT = __dirname; export const TEST_DATA_ROOT = path.join(TEST_ROOT, 'testdata'); diff --git a/js/package-lock.json b/js/package-lock.json index c87a58a3196d6..c16a8b59a3a6f 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -3391,9 +3391,9 @@ } }, "node_modules/normalize-package-data/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -7011,9 +7011,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } diff --git a/js/react_native/app.plugin.js b/js/react_native/app.plugin.js index bce476e9e9657..ed4cfe48563bd 100644 --- a/js/react_native/app.plugin.js +++ b/js/react_native/app.plugin.js @@ -29,7 +29,7 @@ const withOrt = (config) => { config = configPlugin.withDangerousMod(config, [ 'ios', (config) => { - const podFilePath = path.join(config.modRequest.platformProjectRoot, 'PodFile'); + const podFilePath = path.join(config.modRequest.platformProjectRoot, 'Podfile'); const contents = fs.readFileSync(podFilePath, {encoding: 'utf-8'}); const updatedContents = generateCode diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 0b82a9c031baa..2f510308d9306 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,7 +20,9 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | +| Attention | com.microsoft(1+) | need implementing mask and past/present | | AveragePool | 
ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | +| BatchNormalization | ai.onnx(7-8,9-13,14,15+); com.ms.internal.nhwc(7-8,9-13,14,15+) | | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | @@ -31,6 +33,7 @@ Do not modify directly.* | ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation | | Cos | ai.onnx(7+) | | | Cosh | ai.onnx(9+) | | +| CumSum | ai.onnx(11-13,14+) | | | Div | ai.onnx(7-12,13,14+) | | | Einsum | ai.onnx(12+) | | | Elu | ai.onnx(6+) | | @@ -61,6 +64,7 @@ Do not modify directly.* | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | +| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | diff --git a/js/web/lib/onnxjs/attribute-with-cache-key.ts b/js/web/lib/onnxjs/attribute-with-cache-key.ts index 6608b00471e77..5d47570f267a6 100644 --- a/js/web/lib/onnxjs/attribute-with-cache-key.ts +++ b/js/web/lib/onnxjs/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e2c2bc8deccf4..4f4a06c37a94f 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -254,11 +254,9 @@ export class WebGpuBackend { } isQueryEnabled(): boolean { - if (this.device.features.has('timestamp-query') && this.env.webgpu.profilingMode === 'default') { - return true; - } else { - return false; - } + return this.device.features.has('timestamp-query') && + (this.env.webgpu.profiling?.mode === 'default' || + (!this.env.webgpu.profiling?.mode && this.env.webgpu.profilingMode === 'default')); } /** @@ -338,51 +336,26 @@ export class WebGpuBackend { let uniformBufferBinding: GPUBindingResource|undefined; if (programUniforms) { let currentOffset = 0; - let preLength = 0; const offsets: number[] = []; - let maxAlignmentOfField = 1; + programUniforms.forEach(v => { const data = typeof v.data === 'number' ? [v.data] : v.data; if (data.length === 0) { return; } // https://www.w3.org/TR/WGSL/#alignof - let baseAlignment: number; - switch (data.length) { - case 1: - baseAlignment = 4; - break; - case 2: - baseAlignment = 8; - break; - case 3: - baseAlignment = 16; - break; - case 4: - baseAlignment = 16; - break; - case 5: - baseAlignment = 16; - break; - case 6: - baseAlignment = 16; - break; - default: - throw new Error(`unsupported data length: ${data.length}`); - } - - if (preLength === 5 || preLength === 6) { - baseAlignment = 16; - } - if (baseAlignment > maxAlignmentOfField) { - maxAlignmentOfField = baseAlignment; - } + const baseAlignment = data.length <= 2 ? 
data.length * 4 : 16; currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; - preLength = data.length; offsets.push(currentOffset); - currentOffset += data.length * 4; + // When data.length > 4, the uniform variable is of type array,N>, where N = + // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * + // SizeOf(vec4). + currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; }); + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // maxAlignmentOfField to 16 since the underlying buffer has been rounded up to 16. + const maxAlignmentOfField = 16; currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField; const arrayBuffer = new ArrayBuffer(currentOffset); programUniforms.forEach((v, i) => { @@ -413,6 +386,7 @@ export class WebGpuBackend { if (!artifact) { artifact = this.programManager.build(program, normalizedDispatchGroup); this.programManager.setArtifact(key, artifact); + LOG_DEBUG('info', () => `[artifact] key: ${key}, programName: ${program.name}`); } LOG_DEBUG( diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index d66357e729d5d..e6db631c44eea 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -175,8 +175,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { // jsepCreateKernel (name: string, kernel: number, attribute: unknown) => backend.createKernel( name, kernel, attribute, - env.debug || env.webgpu.profilingMode === 'default' ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : - `${kernel}`), + env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts index adba0fb9d022d..ad56b92c1d869 100644 --- a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts +++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index a4d51e68b6a25..8e1ec782079be 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,12 +2,15 @@ // Licensed under the MIT License. 
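The simplified rule in the backend-webgpu.ts hunk above replaces the per-length switch: scalar and vec2 uniforms keep their natural 4- or 8-byte alignment, everything longer aligns to 16 bytes, and lengths greater than 4 are packed as vec4 slots. A minimal TypeScript sketch of that offset computation (the `computeUniformOffsets` helper is invented for illustration and is not part of the patch; zero-length uniforms, which the patch skips, are assumed absent):

```ts
// Sketch of the WGSL uniform-offset rule used above (illustrative only).
const computeUniformOffsets = (lengths: number[]): {offsets: number[]; totalBytes: number} => {
  let currentOffset = 0;
  const offsets: number[] = [];
  for (const len of lengths) {
    // alignof: 4 for scalars, 8 for vec2, 16 for anything longer.
    const alignment = len <= 2 ? len * 4 : 16;
    currentOffset = Math.ceil(currentOffset / alignment) * alignment;
    offsets.push(currentOffset);
    // sizeof: lengths > 4 become array<vec4<f32>, N>, i.e. ceil(len / 4) * 16 bytes.
    currentOffset += len > 4 ? Math.ceil(len / 4) * 16 : len * 4;
  }
  // Round the whole struct up to the max field alignment (16), as the patch does.
  return {offsets, totalBytes: Math.ceil(currentOffset / 16) * 16};
};

console.log(computeUniformOffsets([1, 3, 6]));
// -> { offsets: [ 0, 16, 32 ], totalBytes: 64 }
```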
import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {attention, parseAttentionAttributes} from './ops/attention'; +import {batchNorm} from './ops/batch-norm'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; +import {cumsum, parseCumSumAttributes} from './ops/cumsum'; import {einsum, parseEinsumAttributes} from './ops/einsum'; import {expand} from './ops/expand'; import {gather, parseGatherAttributes} from './ops/gather'; @@ -16,10 +19,11 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; -import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; +import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; @@ -46,19 +50,21 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], + ['Attention', [attention, parseAttentionAttributes]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], + ['BatchNormalization', [batchNorm]], ['BiasAdd', [biasAdd]], ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], ['Ceil', [unaryOps.ceil]], - ['ClipV10', [unaryOps.clipV10]], ['Clip', [unaryOps.clip]], ['Concat', [concat, parseConcatAttributes]], ['Conv', [conv, parseConvAttributes]], ['ConvTranspose', [convTranspose, parseConvTransposeAttributes]], ['Cos', [unaryOps.cos]], ['Cosh', [unaryOps.cosh]], + ['CumSum', [cumsum, parseCumSumAttributes]], ['Div', [binaryOps.div]], ['Einsum', [einsum, parseEinsumAttributes]], ['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]], @@ -86,22 +92,23 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], + ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], ['Pow', [binaryOps.pow]], ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], - ['ReduceMin', [reduceMin, parseReduceAttributes]], - ['ReduceMean', [reduceMean, parseReduceAttributes]], - ['ReduceMax', [reduceMax, parseReduceAttributes]], - ['ReduceSum', [reduceSum, parseReduceAttributes]], - ['ReduceProd', [reduceProd, parseReduceAttributes]], - ['ReduceL1', [reduceL1, parseReduceAttributes]], - ['ReduceL2', [reduceL2, parseReduceAttributes]], - 
['ReduceLogSum', [reduceLogSum, parseReduceAttributes]], - ['ReduceLogSumExp', [reduceLogSumExp, parseReduceAttributes]], - ['ReduceSumSquare', [reduceSumSquare, parseReduceAttributes]], + ['ReduceMin', [reduceMin]], + ['ReduceMean', [reduceMean]], + ['ReduceMax', [reduceMax]], + ['ReduceSum', [reduceSum]], + ['ReduceProd', [reduceProd]], + ['ReduceL1', [reduceL1]], + ['ReduceL2', [reduceL2]], + ['ReduceLogSum', [reduceLogSum]], + ['ReduceLogSumExp', [reduceLogSumExp]], + ['ReduceSumSquare', [reduceSumSquare]], ['Relu', [unaryOps.relu]], ['Resize', [resize, parseResizeAttributes]], ['Sigmoid', [unaryOps.sigmoid]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 089e783d7e22f..3638938df7dbe 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,9 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -50,9 +49,9 @@ const conv2dCommonSnippet = const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'return w[row * wShape[3] + colIn];'; + return 'return w[row * i32(uniforms.w_shape[3]) + colIn];'; case 4: - return 'return w[row * wShape[3] / 4 + colIn];'; + return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];'; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); } @@ -79,13 +78,13 @@ const conv2dCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]'; - const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readXSnippet = ` - let inChannels = wShape[2]; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let inChannels = i32(uniforms.w_shape[2]); + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -99,7 +98,7 @@ const conv2dCommonSnippet = // the 'same' padding type. 
if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { ${coordASnippet} - let xIndex = getIndexFromCoords4D(coord, xShape); + let xIndex = getIndexFromCoords4D(coord, vec4(uniforms.x_shape)); ${getXSnippet(innerElementSizeX)} } return resData;`; @@ -109,7 +108,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimAOuter && col < dimInner) { + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -118,7 +117,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimInner && col < dimBOuter) { + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -143,10 +142,10 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimBOuter) + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} ${biasSnippet(addBias)} ${applyActivation} @@ -181,7 +180,7 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0]; + const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; @@ -194,10 +193,18 @@ export const createConv2DMatMulProgramInfo = const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; const t = tensorTypeToWsglStorageType(inputs[0].dataType); - const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? `vec4<${t}>` : t}>;` - ]; + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const x = + inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); @@ -207,41 +214,40 @@ export const createConv2DMatMulProgramInfo = setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? `vec4<${t}>` : t}>;`); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 
'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } - + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + programUniforms.push(...createTensorShapeVariables(outputShape)); return { name: 'Conv2DMatMul', shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, }), - getShaderSource: () => ` - ${utilFunctions} + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${declareInputs.join('')} - @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? `vec4<${t}>` : t}>; - //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; - - const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); - const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)} const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; ${declareFunctions} ${ conv2dCommonSnippet( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 85cf7bf87f52c..d425155857e14 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -36,16 +36,16 @@ const conv2dTransposeCommonSnippet = const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'return W[getIndexFromCoords4D(coord, wShape)];'; + return 'return w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))];'; case 4: return ` let coord1 = vec4(coordX, coordY, col + 1, rowInner); let coord2 = vec4(coordX, coordY, col + 2, rowInner); let coord3 = vec4(coordX, coordY, col + 3, rowInner); - let v0 = W[getIndexFromCoords4D(coord, wShape)]; - let v1 = W[getIndexFromCoords4D(coord1, wShape)]; - let v2 = W[getIndexFromCoords4D(coord2, wShape)]; - let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + let v0 = w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))]; + let v1 = w[getIndexFromCoords4D(coord1, 
vec4(uniforms.w_shape))]; + let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; + let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; return vec4(v0, v1, v2, v3); `; default: @@ -81,7 +81,7 @@ const conv2dTransposeCommonSnippet = const readASnippet = ` let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -99,17 +99,17 @@ const conv2dTransposeCommonSnippet = let iXC = i32(xC); let xCh = ${col} % inChannels; ${coordASnippet} - return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + return x[getIndexFromCoords4D(coord, vec4(uniforms.x_shape))/${innerElementSize}];`; const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimInner) { + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < dimInner && col < dimBOuter) { + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { ${readASnippet} } return ${type}(0.0);`; @@ -120,8 +120,8 @@ const conv2dTransposeCommonSnippet = let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; if (${ - isChannelsLast ? 'row < dimInner && col < dimBOuter' : - 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : + 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -142,13 +142,13 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimBOuter) { + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueInput; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} ${biasSnippet(addBias)} ${applyActivation} - result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + result[getIndexFromCoords4D(coords, vec4(uniforms.result_shape))/${innerElementSize}] = value; } }`; return userCode; @@ -185,37 +185,46 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + const components = isVec4 ? 4 : 1; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); - - const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 ? 
'vec4' : 'f32'}>;`, - '@group(0) @binding(1) var W: array;' - ]; let declareFunctions = ''; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } + + programUniforms.push(...createTensorShapeVariables(outputShape)); + return { name: 'Conv2DTransposeMatMul', shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms }), - getShaderSource: () => ` - ${utilFunctions} - ${declareInputs.join('\n')} - @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)}; const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); - const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ attributes.kernelShape[isChannelsLast ? 
2 : 3]});
       const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts
index 0ba48a33fbc47..6f2c0231104dc 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts
@@ -19,13 +19,13 @@
 //
 // modified to fit the needs of the project
 
-export const utilFunctions = `
+export const utilFunctions = (strideStr: string) => (`
 fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {
   return dot(coords, vec4<i32>(
       shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
 }
 fn getOutputIndexFromCoords(coords : vec4<i32>) -> i32 {
   return dot(coords, vec4<i32>(
-    outShapeStrides.x, outShapeStrides.y, outShapeStrides.z, 1));
+    i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1));
 }
-`;
+`);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
index 335de01c596b7..47ec16a296712 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -21,8 +21,8 @@
 import {TensorView} from '../../../tensor-view';
 import {ShapeUtil} from '../../../util';
-import {ProgramInfo} from '../../types';
-import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
+import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
+import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
 import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';
 
 import {typeSnippet} from './activation_util';
@@ -112,7 +112,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
     ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
     let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
 
-    let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
+    let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
     var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
 
     var acc: array<vec4<${dataType}>, rowPerThread>;
@@ -322,7 +322,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
                @builtin(workgroup_id) workgroupId : vec3<u32>) {
     let batch = ${splitK ? '0' : 'i32(globalId.z)'};
     ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
-    let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
+    let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
     var kStart = ${splitK ?
`i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -341,13 +341,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, const matMulReadWriteFnSource = (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], batchShapes: Array, isChannelsLast = false): string => { - const batchAShape = batchShapes[0]; - const batchBShape = batchShapes[1]; - const batchShape = batchShapes[2]; - const batchVariable = variables[0]; - const aVariable = variables[1]; - const bVariable = variables[2]; - const outputVariable = variables[3]; + const [batchAShape, batchBShape, batchShape] = batchShapes; + const [batchVariable, aVariable, bVariable, outputVariable] = variables; const broadCastADims = getBroadcastDims(batchAShape, batchShape); const broadCastBDims = getBroadcastDims(batchBShape, batchShape); const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); @@ -384,7 +379,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimAOuter && col < dimInner) + if(row < uniforms.dimAOuter && col < uniforms.dimInner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -396,7 +391,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimInner && col < dimBOuter) + if(row < uniforms.dimInner && col < uniforms.dimBOuter) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -406,7 +401,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < dimAOuter && col < dimBOuter) { + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -430,10 +425,11 @@ export const createMatmulProgramInfo = const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims); - const variables = [batchDims]; - const batchShapes = [outerDimsA, outerDimsB, outerDims]; + const enableBatchUniforms = enableShapesUniforms(outerDims.length); + const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); const dimAOuter = aShape[aShape.length - 2]; @@ -452,39 +448,76 @@ export const createMatmulProgramInfo = const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); - const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); - const output = - outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components); - variables.push(A); - variables.push(B); - variables.push(output); + + const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; + const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); + const aShapeOrRank = enableAShapesUniforms ? 
aShapeTemp.length : aShapeTemp; + + const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; + const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); + const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; + + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); const inputVariables = [A, B]; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + if (enableBatchUniforms) { + programUniforms.push(...createTensorShapeVariables(outerDims)); + } + if (enableAShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(aShapeTemp)); + } + if (enableBShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(bShapeTemp)); + } + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); + inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + const hasBias = inputs.length > 2; const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = - matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); if (hasBias) { const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dimAOuter: i32 = ${dimAOuter}; - const dimBOuter: i32 = ${dimBOuter}; - const dimInner: i32 = ${dimInner}; - ${shaderHelper.declareVariables(...inputVariables, output)} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} ${activationFunction} ${declareFunctions} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} - ${batchDims.impl()}`; + `; + // TODO: turn clipMax and clipMin to uniforms. 
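As context for the shaderCache change just below: with shapes carried in uniforms and `inputDependencies` set to 'rank', two MatMul programs whose inputs differ only in dims (same rank) can now share one compiled artifact, so everything that still changes the generated WGSL has to be folded into the hint. A hedged sketch of the resulting key (the `makeHint` helper is invented; the patch inlines this concatenation):

```ts
// Illustrative only: mirrors the hint concatenation in the MatMul diff below.
const makeHint =
    (activationKey: string, elementsPerThread: readonly number[], isVec4: boolean, isChannelsLast: boolean): string =>
        activationKey + `${elementsPerThread}` + `${isVec4}` + `${isChannelsLast}`;

console.log(makeHint('', [4, 4, 1], true, false));  // '4,4,1truefalse'
```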
return { name: 'MatMul', - shaderCache: {hint: activationAttributes.activationCacheKey}, + shaderCache: { + hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + + `${isVec4}` + + `${isChannelsLast}`, + inputDependencies + }, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts index b6c6853c8f222..1f27525f370f3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -33,23 +33,23 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes) const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIndices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ - `${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`, - `if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) { - value = ${input.getByOffset('inputOffset')}; - bestIndex = i32(lastIndex); + `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, + `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) { + value = ${input.getByIndices('input_indices')}; + best_index = i32(last_index); }`, - '', output.setByOffset('global_idx', 'bestIndex') + '', output.setByOffset('global_idx', 'best_index') ]; }; context.compute( createReduceProgramInfo( - 'ArgMin', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, - attributes.keepDims), + 'ArgMin', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp, + [attributes.axis], DataType.int64, attributes.keepDims), {inputs: [0]}); }; @@ -59,23 +59,23 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes) const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIndices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ - `${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`, - `if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) { - value = ${input.getByOffset('inputOffset')}; - bestIndex = i32(lastIndex); + `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, + `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? 
'>=' : '>'} value) { + value = ${input.getByIndices('input_indices')}; + best_index = i32(last_index); }`, - '', output.setByOffset('global_idx', 'bestIndex') + '', output.setByOffset('global_idx', 'best_index') ]; }; context.compute( createReduceProgramInfo( - 'argMax', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, - attributes.keepDims), + 'argMax', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp, + [attributes.axis], DataType.int64, attributes.keepDims), {inputs: [0]}); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts new file mode 100644 index 0000000000000..e1f2a47301bfb --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; + +export const enum AttentionQkvFormat { + unknown, // enum value not set, or depends on qkv projection implementation details + qkvBNSH, // for non-packed qkv, permuted + qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + qkvBSN3H, // for TRT fused attention, qkv are packed + qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed + qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed +} + +export const enum AttentionMaskType { + none, // No mask + mask1dKeySeqLen, // [batch_size], key sequence length + mask1dEndStart, // [2 * batch_size] with end positions and start positions + mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. 
+ mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown +} + +export interface AttentionParameters { + batchSize: number; + sequenceLength: number; + pastSequenceLength: number; + kvSequenceLength: number; + totalSequenceLength: number; + maxSequenceLength: number; + inputHiddenSize: number; + hiddenSize: number; + vHiddenSize: number; + headSize: number; + vHeadSize: number; + numHeads: number; + isUnidirectional: boolean; + pastPresentShareBuffer: boolean; + maskFilterValue: number; + maskType: AttentionMaskType; + scale: number; + broadcastResPosBias: boolean; + passPastInKv: boolean; + qkvFormat: AttentionQkvFormat; +} + +export interface AttentionAttrs { + numHeads: number; + isUnidirectional: number; + maskFilterValue: number; + scale: number; + doRotary: number; + qkvHiddenSizes: number[]; + pastPresentShareBuffer: boolean; +} + +const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value). + + // Input shapes: + // input (Q/K/V) : (B, S, D_i) + // weights (Q/K/V) : (D_i, D + D + D_v) + // bias (Q/K/V) : (D + D + D_v) + // mask_index : see below + // past (K/V) : (2, B, N, P, H) or NULL + // relative_position_bias : (B, N, S, T) or NULL + + // For mask_index, the following shapes are supported: + // NULL, (B, 1), (1, 1) + // (B), (2 * B), (3 * B + 2) + // (B, T) + // (B, S, T) + // (B, 1, M, M) + // + // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger + // than hidden dimension of Q, K and V. 
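To make the abbreviations above concrete, here is a small worked example of the sizes that the validation code below derives (all numbers are illustrative, not taken from the patch):

```ts
// Worked example: input (B, S, D_i) = (2, 8, 768), num_heads = 12,
// qkv_hidden_sizes unset, no past state and no mask.
const numHeads = 12;
const biasLength = 2304;                                            // D + D + D_v, all equal here
const qHiddenSize = biasLength / 3;                                 // D = 768
const headSize = Math.floor(qHiddenSize / numHeads);                // H = 64
const sequenceLength = 8;                                           // S
const kvSequenceLength = sequenceLength;                            // L = S for packed QKV input
const pastSequenceLength = 0;                                       // P = 0 (past unsupported here)
const totalSequenceLength = kvSequenceLength + pastSequenceLength;  // T = 8
console.log({qHiddenSize, headSize, totalSequenceLength});
// -> { qHiddenSize: 768, headSize: 64, totalSequenceLength: 8 }
```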
+ + const input = inputs[0]; + const weights = inputs[1]; + const bias = inputs[2]; + const maskIndex = inputs[3]; + const past = inputs[4]; + const relativePositionBias = inputs[5]; + + if (past && relativePositionBias) { + throw new Error('Attention cannot have both past and relative_position_bias'); + } + + if (input.dims.length !== 3) { + throw new Error('Input "input" must have 3 dimensions'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[1]; + const inputHiddenSize = input.dims[2]; + + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimensions'); + } + + if (weights.dims.length !== 2) { + throw new Error('Input "weights" is expected to have 2 dimensions'); + } + + if (weights.dims[0] !== inputHiddenSize) { + throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0'); + } + + if (bias.dims[0] !== weights.dims[1]) { + throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"'); + } + + let qHiddenSize = bias.dims[0] / 3; + let kHiddenSize = qHiddenSize; + let vHiddenSize = kHiddenSize; + if (attributes.qkvHiddenSizes.length > 0) { + if (attributes.qkvHiddenSizes.length !== 3) { + throw new Error('qkv_hidden_sizes attribute should have 3 elements'); + } + for (const sz of attributes.qkvHiddenSizes) { + if (sz % attributes.numHeads !== 0) { + throw new Error('qkv_hidden_sizes should be divisible by num_heads'); + } + } + + qHiddenSize = attributes.qkvHiddenSizes[0]; + kHiddenSize = attributes.qkvHiddenSizes[1]; + vHiddenSize = attributes.qkvHiddenSizes[2]; + } + + const kvSequenceLength = sequenceLength; + + if (qHiddenSize !== kHiddenSize) { + throw new Error('qkv_hidden_sizes first element should be same as the second'); + } + + if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) { + throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes'); + } + + let pastSequenceLength = 0; + if (past) { + if (kHiddenSize !== vHiddenSize) { + throw new Error('Input "past" expect k_hidden_size == v_hidden_size'); + } + if (past.dims.length !== 5) { + throw new Error('Input "past" must have 5 dimensions'); + } + if (past.dims[0] !== 2) { + throw new Error('Input "past" first dimension must be 2'); + } + if (past.dims[1] !== batchSize) { + throw new Error('Input "past" second dimension must be batch_size'); + } + if (past.dims[2] !== attributes.numHeads) { + throw new Error('Input "past" third dimension must be num_heads'); + } + if (past.dims[4] !== kHiddenSize / attributes.numHeads) { + throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads'); + } + + if (!attributes.pastPresentShareBuffer) { + pastSequenceLength = past.dims[3]; + } + // TODO: handle past_seq_len + } + + const totalSequenceLength = kvSequenceLength + pastSequenceLength; + const maxSequenceLength = -1; + + const maskType = AttentionMaskType.none; + if (maskIndex) { + // maskType = AttentionMaskType.MASK_UNKNOWN; + // TODO: handle mask + throw new Error('Mask not supported'); + } + + if (past) { + throw new Error('past is not supported'); + } + if (relativePositionBias) { + throw new Error('relativePositionBias is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize, + hiddenSize: qHiddenSize, + vHiddenSize, + headSize: Math.floor(qHiddenSize / attributes.numHeads), + vHeadSize: Math.floor(vHiddenSize / 
attributes.numHeads),
+    numHeads: attributes.numHeads,
+    isUnidirectional: false,
+    pastPresentShareBuffer: false,
+    maskFilterValue: attributes.maskFilterValue,
+    maskType,
+    scale: attributes.scale,
+    broadcastResPosBias: false,
+    passPastInKv: false,
+    qkvFormat: AttentionQkvFormat.qkvBNSH,
+  };
+};
+
+export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
+    createAttributeWithCacheKey({...attributes});
+
+export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => {
+  const components = getMaxComponents(d);
+  const inputHelper = outputVariable('x', input.dataType, input.dims, components);
+
+  let threadMaxValue = 'threadMaxVector';
+  if (components === 2) {
+    threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)';
+  } else if (components === 4) {
+    threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))';
+  }
+  const dataType = tensorTypeToWsglStorageType(input.dataType);
+  let WG = 64;
+  const dComp = d / components;
+  if (dComp < WG) {
+    WG = 1;
+  } else if (dComp / 8 < 64) {
+    WG = Math.ceil(dComp / 8);
+  }
+  const elementsPerWG = Math.ceil(d / components / WG);
+
+  const getShaderSource = (shaderHelper: ShaderHelper) => `
+  const dInv: ${dataType} = 1 / ${d};
+  const dComp = ${d / components};
+  var<workgroup> wgMax: array<f32, ${WG}>;
+  var<workgroup> wgSum: array<f32, ${WG}>;
+
+  ${shaderHelper.declareVariables(inputHelper)}
+  @compute @workgroup_size(${WG}, 1, 1)
+  fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
+    @builtin(local_invocation_index) local_index : u32) {
+    let localOffset = local_index * ${elementsPerWG};
+    let offset: u32 = workgroup_id.x * dComp + localOffset;
+
+    var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')};
+    for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
+      threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector);
+    }
+    wgMax[local_index] = ${threadMaxValue};
+    workgroupBarrier();
+
+    var maxValue = -3.402823e+38f;
+    for (var i = 0u; i < ${WG}; i++) {
+      maxValue = max(wgMax[i], maxValue);
+    }
+
+    var sumVector = ${fillVector('f32', components, '0')};
+    for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
+      sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue);
+    }
+    wgSum[local_index] = ${sumVector('sumVector', components)};
+    workgroupBarrier();
+
+    var sum: f32 = 0;
+    for (var i = 0u; i < ${WG}; i++) {
+      sum += wgSum[i];
+    }
+
+    if (sum == 0) {
+      for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
+        x[offset + i] = ${fillVector(dataType, components, 'dInv')};
+      }
+    } else {
+      for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
+        let f32input = ${castToF32(dataType, components, 'x[offset + i]')};
+        x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum);
+      }
+    }
+  }`;
+
+  context.compute(
+      {
+        name: 'AttentionProbsSoftmax',
+        shaderCache: {hint: `${d}`},
+        getShaderSource,
+        getRunData: () => ({
+          outputs: [],
+          dispatchGroup: {x: n},
+        }),
+      },
+      {inputs: [input], outputs: []});
+};
+
+const computeAttentionProbs =
+    (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined,
+     parameters: AttentionParameters, attributes: AttentionAttrs) => {
+      const probsShape = [
+        parameters.batchSize, parameters.numHeads, parameters.sequenceLength,
+        parameters.kvSequenceLength + parameters.pastSequenceLength
+      ];
+      // TODO: handle mask
+
+
+
+const computeAttentionProbs =
+    (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined,
+     parameters: AttentionParameters, attributes: AttentionAttrs) => {
+      const probsShape = [
+        parameters.batchSize, parameters.numHeads, parameters.sequenceLength,
+        parameters.kvSequenceLength + parameters.pastSequenceLength
+      ];
+      // TODO: handle mask
+
+      const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
+
+      const dataType = tensorTypeToWsglStorageType(q.dataType);
+
+      const components = getMaxComponents(parameters.headSize);
+      const qInput = inputVariable('q', q.dataType, q.dims, components);
+      const kInput = inputVariable('key', key.dataType, key.dims, components);
+      const output = outputVariable('output', q.dataType, probsShape);
+
+      const vectorizedHeadSize = parameters.headSize / components;
+      const M = parameters.sequenceLength;
+      const N = parameters.totalSequenceLength;
+      const K = vectorizedHeadSize;
+
+      const TILE_SIZE = 12;
+
+      const dispatch = {
+        x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
+        y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
+        z: parameters.batchSize * parameters.numHeads
+      };
+
+      const inputs = [q, key];
+      const getShaderSource = (shaderHelper: ShaderHelper) => `
+  const M: u32 = ${M}u;
+  const N: u32 = ${N}u;
+  const K: u32 = ${K}u;
+  const alpha: ${dataType} = ${alpha};
+  const beta: ${dataType} = 1.0;
+  const TILE_SIZE = ${TILE_SIZE}u;
+
+  var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
+  var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
+
+  ${shaderHelper.declareVariables(qInput, kInput, output)}
+
+  @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
+  fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
+    @builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
+    let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
+      workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
+
+    // x holds the N and y holds the M
+    let headIdx = workgroup_id.z;
+    let m = workgroup_id.y * TILE_SIZE;
+    let n = workgroup_id.x * TILE_SIZE;
+    let lm = m + local_id.y;
+    let ln = n + local_id.x;
+
+    let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K;
+    let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K;
+
+    var value = ${fillVector(dataType, components)};
+    for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
+      if (m + local_id.y < M && w + local_id.x < K) {
+        tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x];
+      }
+      if (n + local_id.y < N && w + local_id.x < K) {
+        tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x];
+      }
+      workgroupBarrier();
+
+      for (var k: u32 = 0u; k < TILE_SIZE && w + k < K; k++) {
+        value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k];
+      }
+
+      workgroupBarrier();
+    }
+
+    let headOffset = headIdx * M * N;
+    if (lm < M && ln < N) {
+      let outputIdx = headOffset + lm * N + ln;
+      output[outputIdx] = ${sumVector('value', components)} * alpha;
+    }
+  }`;
+
+      const probs = context.compute(
+          {
+            name: 'AttentionProbs',
+            shaderCache: {hint: JSON.stringify(parameters)},
+            getRunData: () => ({
+              outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
+              dispatchGroup: dispatch,
+            }),
+            getShaderSource,
+          },
+          {inputs, outputs: [-1]})[0];
+
+      computeInPlaceSoftmax(
+          context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
+          parameters.totalSequenceLength);
+
+      return probs;
+    };
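`computeAttentionProbs` above is a tiled batched matmul producing `alpha * Q * K^T` per head, with `alpha` defaulting to `1 / sqrt(headSize)` when the `scale` attribute is zero. A non-tiled reference for one head (a hedged sketch; `attentionScores` is a hypothetical name, not part of the diff):

```ts
// probs[m][n] = alpha * dot(q[m], k[n]); scale === 0 means 'use 1/sqrt(headSize)'.
const attentionScores = (q: number[][], k: number[][], scale: number): number[][] => {
  const alpha = scale === 0 ? 1 / Math.sqrt(q[0].length) : scale;
  return q.map((qRow) => k.map((kRow) => alpha * qRow.reduce((acc, v, i) => acc + v * kRow[i], 0)));
};
```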
+
+const computeVxAttentionScore =
+    (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
+      const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
+
+      const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
+      const vHelper = inputVariable('v', v.dataType, v.dims);
+      const output = outputVariable('output', probs.dataType, outputShape);
+
+      const dataType = tensorTypeToWsglStorageType(probs.dataType);
+
+      const TILE_SIZE = 12;
+      const dispatch = {
+        x: Math.ceil(params.vHeadSize / TILE_SIZE),
+        y: Math.ceil(params.sequenceLength / TILE_SIZE),
+        z: params.batchSize * params.numHeads
+      };
+
+      const getShaderSource = (shaderHelper: ShaderHelper) => `
+  const M: u32 = ${params.sequenceLength}u;
+  const N: u32 = ${params.vHeadSize}u;
+  const K: u32 = ${params.totalSequenceLength}u;
+  const numHeads: u32 = ${params.numHeads}u;
+  const TILE_SIZE = ${TILE_SIZE}u;
+
+  var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
+  var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
+
+  ${shaderHelper.declareVariables(probsHelper, vHelper, output)}
+
+  @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
+  fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
+    @builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
+    let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
+      workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
+
+    let headIdx = workgroup_id.z;
+    let m = workgroup_id.y * TILE_SIZE + local_id.y;
+    let n = workgroup_id.x * TILE_SIZE + local_id.x;
+
+    let offsetA = headIdx * (M * K) + m * K;
+    let offsetB = headIdx * (N * K) + n;
+
+    var value = ${dataType}(0);
+    for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
+      if (m < M && w + local_id.x < K) {
+        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
+      }
+      if (n < N && w + local_id.y < K) {
+        tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
+      }
+      workgroupBarrier();
+      for (var k: u32 = 0u; k < TILE_SIZE && w + k < K; k++) {
+        value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
+      }
+      workgroupBarrier();
+    }
+
+    // write the per-head result back in batch/sequence/hidden layout
+    let batchIdx = workgroup_id.z / numHeads;
+    let currentBatchHeadNumber = workgroup_id.z % numHeads;
+    if (m < M && n < N) {
+      let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + m * ${params.vHiddenSize}
+        + currentBatchHeadNumber * ${params.vHeadSize} + n;
+      output[outputIdx] = value;
+    }
+  }`;
+
+      return context.compute(
+          {
+            name: 'AttentionScore',
+            shaderCache: {hint: JSON.stringify(params)},
+            getRunData: () => ({
+              outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}],
+              dispatchGroup: dispatch,
+            }),
+            getShaderSource,
+          },
+          {inputs: [probs, v], outputs: [0]})[0];
+    };
+
+export const applyAttention =
+    (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
+     _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined,
+     relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
+      const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);
+
+      computeVxAttentionScore(context, probs, v, parameters);
+    };
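`applyAttention` chains the three programs: scores, in-place softmax, then the probabilities-times-V matmul. Putting the two sketches above together gives a single-head CPU reference (hedged; `singleHeadAttention` is hypothetical and reuses `attentionScores` and `softmaxRowInPlace` from the earlier sketches):

```ts
// out = softmax(alpha * Q K^T) V for one head; mirrors probs -> softmax -> score.
const singleHeadAttention = (q: number[][], k: number[][], v: number[][], scale = 0): number[][] => {
  const probs = attentionScores(q, k, scale).map((row) => {
    const r = Float32Array.from(row);
    softmaxRowInPlace(r);
    return Array.from(r);
  });
  // out[m][j] = sum_n probs[m][n] * v[n][j]
  return probs.map((p) => v[0].map((_, j) => p.reduce((acc, w, n) => acc + w * v[n][j], 0)));
};
```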
+
+const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
+  const outputShape = [
+    parameters.batchSize,
+    parameters.numHeads,
+    parameters.sequenceLength,
+    parameters.headSize,
+  ];
+
+  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+
+  const M = parameters.sequenceLength;
+  const K = parameters.inputHiddenSize;
+  const N = parameters.headSize;
+
+  const TILE_SIZE = 12;
+  const dispatch = {
+    x: Math.ceil(parameters.headSize / TILE_SIZE),
+    y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
+    z: parameters.batchSize * parameters.numHeads
+  };
+
+  const getShaderSource = () => `
+  const M: u32 = ${M}u;
+  const K: u32 = ${K}u;
+  const N: u32 = ${N}u;
+  const numHeads: u32 = ${parameters.numHeads};
+  const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u;
+  const TILE_SIZE = ${TILE_SIZE}u;
+
+  var<workgroup> tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
+  var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
+  var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
+  var<workgroup> tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
+
+  @group(0) @binding(0) var<storage, read> input: array<${dataType}>;
+  @group(0) @binding(1) var<storage, read> weight: array<${dataType}>;
+  @group(0) @binding(2) var<storage, read> bias: array<${dataType}>;
+  @group(0) @binding(3) var<storage, read_write> outputQ: array<${dataType}>;
+  @group(0) @binding(4) var<storage, read_write> outputK: array<${dataType}>;
+  @group(0) @binding(5) var<storage, read_write> outputV: array<${dataType}>;
+
+  @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
+  fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
+    @builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
+    let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
+      workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
+
+    let batchIndex = workgroup_id.z / ${parameters.numHeads};
+    let headNumber = workgroup_id.z % ${parameters.numHeads};
+    let m = workgroup_id.y * TILE_SIZE + local_id.y;
+    let n = workgroup_id.x * TILE_SIZE + local_id.x;
+
+    let inputOffset = batchIndex * (M * K) + m * K;
+    let biasOffsetQ = headNumber * ${parameters.headSize};
+    let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ;
+    let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK;
+
+    var valueQ = ${dataType}(0);
+    var valueK = ${dataType}(0);
+    var valueV = ${dataType}(0);
+    for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
+      if (m < M && w + local_id.x < K) {
+        tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];
+      }
+      if (n < N && w + local_id.y < K) {
+        let offset = n + (w + local_id.y) * ldb;
+        tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];
+        tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];
+        tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];
+      }
+      workgroupBarrier();
+      for (var k: u32 = 0u; k < TILE_SIZE && w + k < K; k++) {
+        let inputTileOffset = TILE_SIZE * local_id.y + k;
+        let weightTileOffset = TILE_SIZE * k + local_id.x;
+        valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];
+        valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];
+        valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];
+      }
+      workgroupBarrier();
+    }
+
+    if (m < M && n < N) {
+      let outputIdx = workgroup_id.z * M * N + m * N + n;
+      outputQ[outputIdx] = valueQ + bias[biasOffsetQ + n];
+      outputK[outputIdx] = valueK + bias[biasOffsetK + n];
+      outputV[outputIdx] = valueV + bias[biasOffsetV + n];
+    }
+  }`;
+
+  const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
+
+  return context.compute(
+      {
+        name: 'AttentionPrepare',
+        shaderCache: {hint: JSON.stringify(parameters)},
+        getRunData: () => ({
+          outputs: [
+            {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
+            {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
+            {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
+          ],
+          dispatchGroup: dispatch,
+        }),
+        getShaderSource,
+      },
+      {inputs, outputs: [-1, -1, -1]});
+};
+
+export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
+  const params = validateAttentionInputs(context.inputs, attributes);
+
+  const [q, k, v] = prepare(context, params);
+
+  return applyAttention(
+      context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes);
+};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
new file mode 100644
index 0000000000000..ec9da2613f406
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
@@ -0,0 +1,150 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
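The new batch-norm.ts implementation that follows runs BatchNormalization in inference mode; its shader evaluates `(x - inputMean) / sqrt(inputVar + epsilon) * scale + bias` per element. A one-line reference of that formula (a sketch; `batchNormRef` is not part of the diff):

```ts
// Inference-mode batch norm for a single element with its channel's statistics.
const batchNormRef =
    (x: number, mean: number, variance: number, scale: number, bias: number, epsilon: number): number =>
        (x - mean) / Math.sqrt(variance + epsilon) * scale + bias;
```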
+ +import {env} from 'onnxruntime-common'; + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; + +export interface BatchNormAttributes extends AttributeWithCacheKey { + readonly epsilon: number; + readonly momentum: number; + readonly spatial: boolean; + readonly trainingMode: boolean; + readonly format: 'NHWC'|'NCHW'; + readonly outputCount: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => { + if (!inputs || inputs.length !== 5) { + throw new Error('BatchNormalization requires 5 inputs'); + } + + const checkShapeEqual = (actual: readonly number[], expected: readonly number[], message: string) => { + const r = expected.length; + if (r !== actual.length) { + throw new Error(`${message}: num dimensions != ${r}`); + } + expected.forEach((v, i) => { + if (v !== actual[i]) { + throw new Error(`${message}: dim[${i}] do not match`); + } + }); + }; + + if (inputs[0].dims.length > 1) { + const shape = attributes.format === 'NHWC' ? + (attributes.spatial ? inputs[0].dims.slice(-1) : + inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))) : + inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined); + checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale'); + checkShapeEqual(inputs[2].dims, shape, 'Invalid input B'); + checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean'); + checkShapeEqual(inputs[4].dims, shape, 'Invalid input var'); + } else { + checkShapeEqual(inputs[1].dims, [1], 'Invalid input scale'); + checkShapeEqual(inputs[2].dims, [1], 'Invalid input B'); + checkShapeEqual(inputs[3].dims, [1], 'Invalid input mean'); + checkShapeEqual(inputs[4].dims, [1], 'Invalid input var'); + } +}; + +const createBatchNormInferenceProgramInfo = + (inputs: readonly TensorView[], attributes: BatchNormAttributes): ProgramInfo => { + const {epsilon, spatial, format} = attributes; + const yShape = inputs[0].dims; + const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1; + const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; + const outputSize = ShapeUtil.size(yShape) / components; + // Only support uniforms for opset version >= 9 (spatial = true). + const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const shapeOrRank = useShapesUniforms ? yShape.length : yShape; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents); + const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents); + const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents); + const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components); + // TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type. + // Otherwise, the shader compilation will fail. + const calcCOffset = (): string => { + let cOffset = ''; + if (spatial) { + cOffset = `let cOffset = ${ + yShape.length === 1 ? '0u' : + format === 'NHWC' ? 
`outputIndices[${yShape.length - 1}] / ${components}` : + 'outputIndices[1]'};`; + } else { + if (format === 'NCHW') { + cOffset = ` + ${y.indicesSet('outputIndices', '0', '0')} + let cOffset = ${y.indicesToOffset('outputIndices')};`; + } else { + // update C channel. + cOffset = `var cIndices = ${scale.type.indices}(0); + cIndices[0] = outputIndices[${yShape.length - 1}];`; + // update D1 x ... x Dn channels. + for (let i = 1; i < scale.rank; i++) { + cOffset += `cIndices[${i}] = outputIndices[${i}];`; + } + cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`; + } + } + return cOffset; + }; + const getInferenceModeShaderSource = (helper: ShaderHelper) => ` + const epsilon = ${epsilon}; + ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)} + ${helper.mainStart()} + ${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)}; + ${calcCOffset()} + let scale = ${scale.getByOffset('cOffset')}; + let bias = ${bias.getByOffset('cOffset')}; + let inputMean = ${inputMean.getByOffset('cOffset')}; + let inputVar = ${inputVar.getByOffset('cOffset')}; + let x = ${x.getByOffset('global_idx')}; + let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias; + ${y.setByOffset('global_idx', 'value')} + }`; + return { + name: 'BatchNormalization', + shaderCache: { + hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`, + inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined, + }, + getShaderSource: getInferenceModeShaderSource, + getRunData: () => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? 
+          [
+            {type: 'uint32', data: outputSize},
+            ...createTensorShapeVariables(yShape),
+          ] :
+          [
+            {type: 'uint32', data: outputSize},
+          ],
+      }),
+    };
+  };
+
+export const parseBatchNormAttributes = (attributes: Record<string, unknown>): BatchNormAttributes =>
+    createAttributeWithCacheKey(attributes as Omit<BatchNormAttributes, keyof AttributeWithCacheKey>);
+
+export const batchNorm = (context: ComputeContext, attributes: Record<string, unknown>): void => {
+  const {inputs, outputCount} = context;
+  const updatedAttributes = parseBatchNormAttributes({...attributes, outputCount});
+  if (env.webgpu.validateInputContent) {
+    validateInputs(inputs, updatedAttributes);
+  }
+  if (attributes.trainingMode) {
+    throw new Error('BatchNormalization trainingMode is not supported yet.');
+  } else {
+    context.compute(createBatchNormInferenceProgramInfo(inputs, updatedAttributes));
+  }
+};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
index 14eefc344f3c0..a81a7a8f1df5c 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
@@ -5,7 +5,7 @@
 import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {ComputeContext, ProgramInfo} from '../types';
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
 import {erfImpl} from './unary-op';

 const validateInputs = (inputs: readonly TensorView[]): void => {
@@ -35,6 +35,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
   const output = outputVariable('output', inputs[0].dataType, outputShape, 4);
   const outputSize = ShapeUtil.size(outputShape) / 4;
+  const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

   const getShaderSource = (shaderHelper: ShaderHelper) => `
   const M_SQRT2 = sqrt(2.0);
@@ -42,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI

   ${shaderHelper.declareVariables(input, bias, output)}

-  ${erfImpl('vec4f')}
+  ${erfImpl(`vec4<${dataType}>`, dataType)}

   ${shaderHelper.mainStart()}
   ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
index 0841da11d9e86..c033c0ba05356 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -17,8 +17,9 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{

 const createBinaryOpProgramShader =
     (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
-     vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number,
-     typeOutput: number, useShapesUniforms: boolean, additionalImplementation?: string) => {
+     vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall,
+     typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean,
+     additionalImplementation?: string) => {
      let expressionScalar: BinaryCustomExpression;
      let expressionVector: BinaryCustomExpression;
      if (typeof funcCall === 'string') {
@@ -42,6 +43,8 @@ const createBinaryOpProgramShader =
      if (doBroadcast) {
        const isAOneElement = ShapeUtil.size(dimsA) === 1;
        const isBOneElement = ShapeUtil.size(dimsB) === 1;
+        const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0;
+        const bLastDimDivisibleBy4 = dimsB.length > 0 &&
dimsB[dimsB.length - 1] % 4 === 0; if (isAOneElement || isBOneElement) { assignment = output.setByOffset( 'global_idx', @@ -55,7 +58,14 @@ const createBinaryOpProgramShader = let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)}; ${ output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + 'global_idx', + expressionVector( + sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 ? + a.getByOffset('offsetA / 4u') : + `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`, + sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 ? + b.getByOffset('offsetB / 4u') : + `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`))} `; } } else { @@ -118,6 +128,7 @@ const createBinaryOpProgramInfo = let outputSize = ShapeUtil.size(a.dims); let vectorize = false; + let sharedDimensionDivisibleBy4 = false; // TODO: deal with zero-sized tensors (eg. dims=[1,0]) const cacheKeyAux = [isBroadcast]; @@ -130,8 +141,12 @@ const createBinaryOpProgramInfo = outputSize = ShapeUtil.size(outputShape); const isAOneElement = ShapeUtil.size(a.dims) === 1; const isBOneElement = ShapeUtil.size(b.dims) === 1; + const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; cacheKeyAux.push(isAOneElement); cacheKeyAux.push(isBOneElement); + cacheKeyAux.push(aLastDimDivisibleBy4); + cacheKeyAux.push(bLastDimDivisibleBy4); // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { @@ -143,7 +158,10 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) { + if (sharedDimension % 4 === 0) { + sharedDimensionDivisibleBy4 = true; + vectorize = true; + } else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) { vectorize = true; } } else { @@ -160,8 +178,8 @@ const createBinaryOpProgramInfo = inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( - shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, - outputDataType, useShapesUniforms, additionalImplementation), + shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, + a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 38dc14f23682e..0eb0d40a3ea5e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -58,10 +58,11 @@ interface IndicesHelperTypes { * create an instance of an indices helper: * - `inputVariable()`: create an indices helper instance for an input. * - `outputVariable()`: create an indices helper instance for an output. + * - `internalVariable()`: create an indices helper instance for an internal variable. * * An indices helper instance contains helper functions for the following operations: * - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an - * input or an output) and `shape`(the passed in shape). 
+ * input, an output or an internal variable) and `shape`(the passed in shape). * - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at * runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView). * - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate @@ -192,9 +193,9 @@ export interface IndicesHelper { readonly name: string; /** - * whether the helper is for an input or an output. + * whether the helper is for an input, an output or an internal variable. */ - readonly usage: 'input'|'output'; + readonly usage: 'input'|'output'|'internal'; /** * the rank of the input or output. @@ -324,18 +325,36 @@ export const sumVector = (name: string, components: number) => { return name; }; +/** + * A helper function that returns variable element at index. + * @param name - the name of variable. + * @param index - the index of variable element. + * @param length - the length of variable. + */ +export const getElementAt = (name: string, index: number|string, length: number): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } +}; + /** * A helper function to get a IndicesHelper for a given input or output. * * @param name - the name of the input or output. * @param tensorType - the tensor type of the input or output. * @param shapeOrRank - the tensor shape or the rank of the input or output. - * @param isInput - whether the helper is for an input or an output. + * @param usage - the usage of the indices helper. * @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, + (name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'], components: 1|2|3|4): IndicesHelper => { const useUniform = typeof shapeOrRank === 'number'; const rank = useUniform ? shapeOrRank : shapeOrRank.length; @@ -361,11 +380,12 @@ const createIndicesHelper = const uniformPrefix = useUniform ? 'uniforms.' 
: ''; const shape = `${uniformPrefix}${name}_shape`; const strides = `${uniformPrefix}${name}_strides`; + let o2iSnippet = ''; for (let i = 0; i < rank - 1; i++) { o2iSnippet += ` - let dim${i} = current / ${strides}[${i}]; - let rest${i} = current % ${strides}[${i}]; + let dim${i} = current / ${getElementAt(strides, i, rank)}; + let rest${i} = current % ${getElementAt(strides, i, rank)}; indices[${i}] = dim${i}; current = rest${i}; `; @@ -388,7 +408,7 @@ const createIndicesHelper = const offsets: string[] = []; if (rank >= 2) { for (let i = rank - 1; i >= 0; i--) { - offsets.push(`${strides}[${i}] * (indices[${i}])`); + offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`); } } @@ -409,7 +429,7 @@ const createIndicesHelper = if (rank < 2) { return `${varIndices}`; } else { - return `${varIndices}[${idx}]`; + return `${getElementAt(varIndices, idx, rank)}`; } }; @@ -417,7 +437,7 @@ const createIndicesHelper = if (rank < 2) { return `${varIndices}=${value};`; } else { - return `${varIndices}[${idx}]=${value};`; + return `${getElementAt(varIndices, idx, rank)}=${value};`; } }; @@ -612,7 +632,7 @@ const createIndicesHelper = getByOffset, getByIndices, // isVec4, - usage: isInput ? 'input' : 'output', + usage, name, strides, shape, @@ -631,7 +651,7 @@ const createIndicesHelper = */ export const inputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, true, components); + createIndicesHelper(name, type, shapeOrRank, 'input', components); /** * Create a IndicesHelper for an output. @@ -644,7 +664,23 @@ export const inputVariable = */ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, false, components); + createIndicesHelper(name, type, shapeOrRank, 'output', components); + +/** + * Create a IndicesHelper for an internal variable. + * + * @param name - the name of the variable. + * @param type - the tensor type of the variable. + * @param shapeOrRank - the tensor shape or the rank of the variable. + * @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1. + * @returns an IndicesHelper for the variable. + */ +export const internalVariable = + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'internal', components); + +export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** * A ShaderHelper is a helper class for generating WGSL code. @@ -695,8 +731,28 @@ export interface ShaderHelper { /** * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. + * + * @param name - the name of the uniform. + * @param type - the type of the uniform. + * @param length - the length of the uniform, default to 1 when it is not provided. + */ + registerUniform(name: string, type: string, length?: number): ShaderHelper; + + /** + * A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms. + * + * @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and + * `type`. 
+ */ + registerUniforms(uniforms: UniformsArrayType): ShaderHelper; + + /** + * A helper function to register multiple internal variables. Can be called multiple times to register multiple + * internal variables. + * + * @param variables - an array of IndicesHelper for the variables. */ - registerUniform(name: string, type: string): ShaderHelper; + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -716,14 +772,14 @@ class ShaderHelperImpl implements ShaderHelper { const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, @builtin(local_invocation_id) local_id : vec3` : - `@builtin(local_invocation_index) local_index : u32, + `@builtin(local_invocation_index) local_idx : u32, @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? - 'let global_idx = global_id.x;' : + 'let global_idx = global_id.x; let local_idx = local_id.x;' : `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ - workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`; + workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) fn main(${paramList}) { @@ -731,16 +787,24 @@ class ShaderHelperImpl implements ShaderHelper { `; } - private declareVariable(variable: IndicesHelper, bindingIndex: number): string { - this.indicesHelpers.push(variable); + private appendVariableUniforms(variable: IndicesHelper): void { if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { - this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); + this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank}); } if (variable.strides.startsWith('uniforms.')) { - this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); + this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank}); } } + } + + private declareVariable(variable: IndicesHelper, bindingIndex: number): string { + if (variable.usage === 'internal') { + throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.'); + } + this.variables.push(variable); + this.appendVariableUniforms(variable); + const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; @@ -750,21 +814,47 @@ class ShaderHelperImpl implements ShaderHelper { return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); } - registerUniform(name: string, type: string): ShaderHelper { - this.uniforms.push({name, type}); + private registerInternalVariable(variable: IndicesHelper): void { + if (variable.usage !== 'internal') { + throw new Error( + 'cannot use input or output variable with registerInternalVariable(). 
use declareVariables() instead.'); + } + + this.internalVariables.push(variable); + this.appendVariableUniforms(variable); + } + + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper { + variables.forEach(v => this.registerInternalVariable(v)); + return this; + } + + registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper { + this.uniforms.push({name, type, length}); return this; } - private indicesHelpers: IndicesHelper[] = []; - private uniforms: Array<{name: string; type: string}> = []; + registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper { + this.uniforms = this.uniforms.concat(additionalUniforms); + return this; + } + + private internalVariables: IndicesHelper[] = []; + private variables: IndicesHelper[] = []; + private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { return ''; } const uniformSnippets: string[] = []; - for (const {name, type} of this.uniforms) { - uniformSnippets.push(`${name}:${type}`); + for (const {name, type, length} of this.uniforms) { + if (length && length > 4) { + uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } else { + const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; + uniformSnippets.push(`${name}:${typeTemp}`); + } } return ` @@ -777,7 +867,8 @@ class ShaderHelperImpl implements ShaderHelper { * Get additional implementation that needs to be added to the shader source. */ get additionalImplementations(): string { - return this.uniformDeclaration() + this.indicesHelpers.map(i => i.impl()).join('\n'); + return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') + + this.internalVariables.map(i => i.impl()).join('\n'); } } @@ -807,5 +898,5 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly return dims; }; -// TODO: remove this limitation once >4D dims are supported by uniform. -export const enableShapesUniforms = (rank: number): boolean => rank <= 4; +// TODO: remove this when all related uses have been removed. +export const enableShapesUniforms = (_rank: number): boolean => true; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index e880afe09a5d8..32b1d52ed94ca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -209,18 +209,20 @@ const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; - const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1) { + const outputShape = adjustedAttributes.outputShape; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's + // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit + // utilization rate is very low. + if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); return; } - const outputShape = adjustedAttributes.outputShape; const outHeight = outputShape[isChannelsLast ? 
1 : 2]; const outWidth = outputShape[isChannelsLast ? 2 : 3]; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; const weightHeight = inputs[1].dims[2]; const weightWidth = inputs[1].dims[3]; - const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; @@ -240,6 +242,7 @@ const convTranspose2d = // STEP.2: prepare reshaped inputs const convTransposeInputs = [inputs[0], transposedWeight]; + const hasBias = inputs.length === 3; if (hasBias) { if (!isChannelsLast && inputs[2].dims.length === 1) { convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index c7ea0cffe51c3..33a5db7ff6b25 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -10,6 +10,7 @@ import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {createGroupedConvProgramInfo} from './conv-grouped'; import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; +import {createNaiveMatmulProgramInfo} from './matmul'; import {createTransposeProgramInfo} from './transpose'; export const calculateOutputShape = @@ -195,9 +196,19 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (hasBias) { matmulInputs.push(inputs[2]); } - context.compute( - createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), - {inputs: matmulInputs}); + const N = matmulOutputShape[2]; + const K = matmulInputs[0].dims[matmulInputs[0].dims.length - 1]; + // Tune the threshold. + if (N < 8 && K < 8) { + context.compute( + createNaiveMatmulProgramInfo( + matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + {inputs: matmulInputs}); + } else { + context.compute( + createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + {inputs: matmulInputs}); + } return; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts new file mode 100644 index 0000000000000..2ff909c30e62e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper} from './common'; + + +export interface CumSumAttributes extends AttributeWithCacheKey { + readonly exclusive: boolean; + readonly reverse: boolean; +} +const createCumsumProgramInfo = + (inputType: number, inputShape: readonly number[], axisInput: TensorView, attributes: CumSumAttributes): + ProgramInfo => { + const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape. 
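The `lowerLimit`/`upperLimit` strings built just below encode how CumSum's `exclusive` and `reverse` attributes shift the summation window. A CPU reference of the same index bounds for a 1-D input (hedged sketch; `cumsumRef` is a hypothetical name, not part of the diff):

```ts
// forward: sum over [0, i] inclusive or [0, i) exclusive;
// reverse: sum over [i, n) inclusive or (i, n) exclusive.
const cumsumRef = (data: number[], exclusive: boolean, reverse: boolean): number[] =>
    data.map((_, i) => {
      const first = reverse ? i + (exclusive ? 1 : 0) : 0;
      const last = reverse ? data.length : i + (exclusive ? 0 : 1);
      let sum = 0;
      for (let j = first; j < last; j++) {
        sum += data[j];
      }
      return sum;
    });
```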
+ const rank = inputShape.length; // input/output rank + const input = inputVariable('input', inputType, rank); + const output = outputVariable('output', inputType, rank); + const axisValue = axisInput.dataType === DataType.int32 ? axisInput.getInt32Array()[0] : + Number(axisInput.getBigInt64Array()[0]); + const axis = ShapeUtil.normalizeAxis(axisValue, rank); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `; + const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank); + const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0'; + const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1'); + return ` + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axis', 'u32') + .declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + var inputIndices = ${output.offsetToIndices('global_idx')}; + var sum = ${output.type.value}(0); + let first : i32 = ${lowerLimit}; + let last : i32 = ${upperLimit}; + for (var i : i32 = first; i < last; i++) { + ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(i)')}; + sum = sum + ${input.getByIndices('inputIndices')}; + } + ${output.setByOffset('global_idx', 'sum')}; + }`; + }; + return { + name: 'CumSum', + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']}, + getRunData: () => ({ + outputs: [{dims: inputShape, dataType: inputType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ + {type: 'uint32', data: outputSize}, {type: 'int32', data: axis}, + ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) + ] + + }), + getShaderSource + }; + }; + + +export const cumsum = (context: ComputeContext, attributes: CumSumAttributes): void => { + const inputShape = context.inputs[0].dims; + const inputType = context.inputs[0].dataType; + const axis = context.inputs[1]; + context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), {inputs: [0]}); +}; + +export const parseCumSumAttributes = (attributes: Record): CumSumAttributes => { + const exclusive = attributes.exclusive as number === 1; + const reverse = attributes.reverse as number === 1; + return createAttributeWithCacheKey({exclusive, reverse}); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index a233d37a79e65..4db7c04ad67be 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -4,9 +4,10 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -101,7 +102,7 @@ class EinsumEquation { this.outputDims.push(info.dimValue); } }); - this.rhs = this.processTerm(rhs, true, this.outputDims); + this.rhs = this.processTerm(rhs, false, this.outputDims); } // End of EinsumEqation 
constructor // Add a symbol to the equation @@ -157,12 +158,12 @@ class EinsumEquation { } // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling for (let j = 0; j < ellipsisDims.length; j++) { - const symbol = String.fromCharCode('0'.charCodeAt(0) + i); + const symbol = String.fromCharCode('0'.charCodeAt(0) + j); einsumTerm.addSymbol(symbol, i + j); this.addSymbol(symbol, dims[nextDim++], index); } } else { - einsumTerm.addSymbol(symbol, i); + einsumTerm.addSymbol(symbol, i + (this.hasEllipsis ? this.ellipsisDims.length - 1 : 0)); this.addSymbol(symbol, dims[nextDim++], index); } }); @@ -177,101 +178,132 @@ class EinsumEquation { outputDims: number[]; // Output dimensions of the equation } // End of class EinsumEquation -const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => { - const dataType = inputs[0].dataType; - const inputVars = new Array(inputs.length); - for (let i = 0; i < inputs.length; ++i) { - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); - } - const outputShape = einsumEquation.outputDims; - const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', dataType, outputShape); - const idxCopy: string[] = []; - const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys()); - const initProd = 'var prod = 1.0;'; - const initSum = 'var sum = 0.0;'; - const updateSum = 'sum += prod;'; - const reduceOpsSetIndices: string[] = []; - const reduceOpsLoopHeaders: string[] = []; - const reduceOpsLoopFooters: string[] = []; - const reduceOpCompute: string[] = []; - const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length; - einsumEquation.symbolToInfo.forEach((info, symbol) => { - if (rhsSymbols.includes(symbol)) { - const outputIndex = rhsSymbols.indexOf(symbol); - einsumEquation.lhs.forEach((term, i) => { - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); +const appendMax = (name: string): string => name + '_max'; + +const createEinsumProgramInfo = + (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, + einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { + const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims); + const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + const outputSize = ShapeUtil.size(outputShape); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? 
outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); + const uniformsSymbols = + [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const idxCopy: string[] = []; + const initProd = 'var prod = 1.0;'; + const initSum = 'var sum = 0.0;'; + const updateSum = 'sum += prod;'; + const reduceOpsSetIndices: string[] = []; + const reduceOpsLoopHeaders: string[] = []; + const reduceOpsLoopFooters: string[] = []; + const reduceOpCompute: string[] = []; + const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size; + einsumEquation.symbolToInfo.forEach((info, symbol) => { + if (einsumEquation.rhs.symbolToIndices.has(symbol)) { + const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0]; + if (outputIndex !== undefined) { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + idxCopy.push(`${ + inputVars[i].indicesSet( + `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); + }); + } + }); + } + } else { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); + }); + reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); + } + }); + reduceOpsLoopHeaders.push( + `for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`); + reduceOpsLoopFooters.push('}'); } - indices.forEach((index) => { - idxCopy.push(`${ - inputVars[i].indicesSet(`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); - }); - } - }); - } else { - einsumEquation.lhs.forEach((term, i) => { - const info = einsumEquation.symbolToInfo.get(symbol); - if (info === undefined) { - throw new Error('Invalid symbol error'); - } - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); + }); + const reduceOps = isReduceOpsWithoutLoop ? 
+ [ + ...idxCopy, + `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` + ] : + [ + ...idxCopy, + initSum, + ...reduceOpsLoopHeaders, + ...reduceOpsSetIndices, + initProd, + ...reduceOpCompute, + updateSum, + ...reduceOpsLoopFooters, + ]; + return ` + ${ + shaderHelper + .registerUniforms(uniformsSymbols.map((symbol) => ({name: `${appendMax(symbol)}`, type: 'u32'}))) + .registerUniform('outputSize', 'u32') + .declareVariables(...inputVars, output)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + var outputIndices = ${output.offsetToIndices('global_idx')}; + ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} + ${reduceOps.join('\n')}; + ${output.setByOffset('global_idx', 'sum')}; + }`; + }; + return { + name: 'Einsum', + shaderCache: { + hint: einsumEquation.equation, + inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims') + }, + getRunData: () => { + // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The + // filter is added to make sure that dimValue is never 0. + const programUniformsInit: ProgramUniform[] = + uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) + .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: 'uint32', data: outputSize}); + const programUniforms: ProgramUniform[] = + inputShapes.filter((_, index) => enableInputShapesUniforms[index]) + .map((dims, _) => [...createTensorShapeVariables(dims)]) + .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); } - indices.forEach((index) => { - reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); + return ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }); - reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); - } - }); - reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${ - einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`); - reduceOpsLoopFooters.push('}'); - } - }); - const reduceOps = isReduceOpsWithoutLoop ? 
- [ - ...idxCopy, - `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` - ] : - [ - ...idxCopy, - initSum, - ...reduceOpsLoopHeaders, - ...reduceOpsSetIndices, - initProd, - ...reduceOpCompute, - updateSum, - ...reduceOpsLoopFooters, - ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - var outputIndices = ${output.offsetToIndices('global_idx')}; - ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} - ${reduceOps.join('\n')}; - ${output.setByOffset('global_idx', 'sum')}; - }`; - return { - name: 'Einsum', - shaderCache: {hint: einsumEquation.equation}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} - }), - getShaderSource, - }; -}; + }, + getShaderSource, + }; + }; export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - context.compute(createEinsumProgramInfo(context.inputs, einsumEquation)); + const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); + const outputShape = einsumEquation.outputDims; + const inputShapes = context.inputs.map((input, _) => input.dims); + context.compute(createEinsumProgramInfo( + enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 5680af4787b6a..3dc4e957e0fee 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -44,37 +45,66 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const inputShape = inputs[0].dims; const shape = Array.from(inputs[1].getBigInt64Array(), Number); const outputShape: number[] = calculateOutputShape(inputShape, shape); - const outputSize = ShapeUtil.size(outputShape); - const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, inputShape); - const output = outputVariable('output', dataType, outputShape); + const components = dataType === DataType.bool ? 
4 : 1; + const outputSize = ShapeUtil.size(outputShape) / components; + + const enableInputShapeUniform = enableShapesUniforms(inputShape.length); + const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const inputShape = ${input.indices(...inputShape)}; - ${shaderHelper.declareVariables(input, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; - for (var i = 0; i < ${inputShape.length}; i++) { - if (${input.indicesGet('inputShape', 'i')} == 1) { - ${input.indicesSet('inputIndices', 'i', 0)} - } else { - ${ - input.indicesSet( - 'inputIndices', 'i', output.indicesGet('outputIndices', `i + ${outputShape.length - inputShape.length}`))} - } + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; + const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; + const input = inputVariable('input', dataType, inputShapeOrRank, components); + const output = outputVariable('output', dataType, outputShapeOrRank, components); + let assignment: string; + if (dataType === DataType.bool) { + const singleAssignment = (resStr: string, x: number, typeCast = '') => ` + let outputIndices${x} = ${output.offsetToIndices(`outputOffset + ${x}u`)}; + let offset${x} = ${input.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let index${x} = offset${x} / 4u; + let component${x} = offset${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${input.getByOffset(`index${x}`)}[component${x}]); + `; + assignment = ` + let outputOffset = global_idx * ${components}; + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + ${output.setByOffset('global_idx', 'data')} + }`; + } else { + assignment = ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + let inputOffset = ${input.broadcastedIndicesToOffset('outputIndices', output)}; + ${output.setByOffset('global_idx', input.getByOffset('inputOffset'))} + }`; } - ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} - }`; + return ` + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} + ${assignment}`; + }; + + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + if (enableInputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(inputShape)); + } + if (enableOutputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } return { name: 'Expand', - shaderCache: {hint: `${outputShape}`}, + shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 
'rank' : 'dims']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index 9924a50e2ae6f..a945954adcaa4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherElementsAttributes extends AttributeWithCacheKey { axis: number; @@ -32,65 +32,59 @@ const createGatherElementsProgramInfo = const inputShape = inputs[0].dims; const inputOutputDataType = inputs[0].dataType; const inputRank = inputShape.length; - const inputStrides = ShapeUtil.computeStrides(inputShape); - const inputSize = ShapeUtil.size(inputShape); const indicesShape = inputs[1].dims; const indicesDataType = inputs[1].dataType; - const indicesSize = ShapeUtil.size(indicesShape); - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); const axisDimLimit = inputShape[axis]; const outputShape = indicesShape.slice(0); const outputSize = ShapeUtil.size(outputShape); - const input = inputVariable('input', inputOutputDataType, inputShape); - const indices = inputVariable('indices', indicesDataType, [indicesSize]); - const output = outputVariable('output', inputOutputDataType, outputShape); + const input = inputVariable('input', inputOutputDataType, inputRank); + const indices = inputVariable('indicesInput', indicesDataType, indicesShape.length); + const output = outputVariable('output', inputOutputDataType, outputShape.length); + + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + programUniforms.push(...createTensorShapeVariables(inputShape)); + programUniforms.push(...createTensorShapeVariables(indicesShape)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor // Input data will be treated as u32 or two u32 for 8-byte tensors const getShaderSource = (shaderHelper: ShaderHelper) => ` - const inputStrides = array(${inputStrides.map(i => `${i}u`).join(',')}); - ${shaderHelper.declareVariables(input, indices, output)} + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(input, indices, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = 
${output.offsetToIndices('global_idx')}; var idx = ${indices.getByOffset('global_idx')}; if (idx < 0) { - idx = idx + ${axisDimLimit}; - } - - var srcOffset = u32(0); - - for (var i = 0; i < ${inputShape.length}; i++) { - if (i == ${axis}) { - srcOffset += u32(idx) * inputStrides[i]; - } else { - srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i]; - } - } - - // Should never hit this with valid values in indices - // This is a guard against malicious data in the indices input - if (srcOffset < 0 || srcOffset >= ${inputSize}) { - return; + idx = idx + uniforms.axisDimLimit; } + var inputIndices = ${input.type.indices}(outputIndices); + ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(idx)')}; + let value = ${input.getByIndices('inputIndices')}; - output[global_idx] = input[srcOffset]; + ${output.setByOffset('global_idx', 'value')}; }`; return { name: 'GatherElements', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 5d6d6debadb9a..53ca094abfd62 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -29,7 +30,8 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath outputShape.splice(axis, 1, ...indicesShape); const axisDimLimit = inputShape[axis]; - const outputSize = ShapeUtil.size(outputShape); + const components = inputs[0].dataType === DataType.bool ? 4 : 1; + const outputSize = ShapeUtil.size(outputShape) / components; const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; @@ -38,10 +40,6 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank); - const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); - const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; if (enableInputShapesUniforms) { @@ -58,46 +56,75 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); - const calcDataIndices = (): string => { - const indicesRank = indicesShape.length; - let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; - for (let i = 0; i < indicesRank; i++) { - calcStr += `${indicesRank > 1 ? 
`indicesIndices[${i}]` : 'indicesIndices'} = ${ - outputShape.length > 1 ? `outputIndices[uniforms.axis + ${i}]` : 'outputIndices'};`; - } - calcStr += ` - var idx = ${indices.getByIndices('indicesIndices')}; - if (idx < 0) { - idx = idx + uniforms.axisDimLimit; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components); + const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components); + + const calcDataIndices = (x: number|string): string => { + const indicesRank = indicesShape.length; + let calcStr = `var indicesIndices${x} = ${indices.type.indices}(0);`; + for (let i = 0; i < indicesRank; i++) { + calcStr += `${indicesRank > 1 ? `indicesIndices${x}[${i}]` : `indicesIndices${x}`} = ${ + outputShape.length > 1 ? `outputIndices${x}[uniforms.axis + ${i}]` : `outputIndices${x}`};`; + } + calcStr += ` + var idx${x} = ${indices.getByIndices(`indicesIndices${x}`)}; + if (idx${x} < 0) { + idx${x} = idx${x} + uniforms.axisDimLimit; + } + var dataIndices${x} = ${data.type.indices}(0); + `; + for (let i = 0, j = 0; i < inputRank; i++) { + if (i === axis) { + calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = u32(idx${x});`; + j += indicesRank; + } else { + calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = ${ + outputShape.length > 1 ? `outputIndices${x}[${j}]` : `outputIndices${x}`};`; + j++; } - var dataIndices = ${data.type.indices}(0); - `; - for (let i = 0, j = 0; i < inputRank; i++) { - if (i === axis) { - calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`; - j += indicesRank; - } else { - calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${ - outputShape.length > 1 ? 
`outputIndices[${j}]` : 'outputIndices'};`; - j++; } + return calcStr; + }; + let assignment: string; + if (inputs[0].dataType === DataType.bool) { + const singleAssignment = (resStr: string, x: number, typeCast = '') => ` + let outputIndices${x} = ${output.offsetToIndices(`outputOffset + ${x}u`)}; + ${calcDataIndices(x)}; + let offset${x} = ${data.indicesToOffset(`dataIndices${x}`)}; + let index${x} = offset${x} / 4u; + let component${x} = offset${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${data.getByOffset(`index${x}`)}[component${x}]); + `; + assignment = ` + let outputOffset = global_idx * ${components}; + var value = vec4<u32>(0); + ${singleAssignment('value', 0, 'u32')} + ${singleAssignment('value', 1, 'u32')} + ${singleAssignment('value', 2, 'u32')} + ${singleAssignment('value', 3, 'u32')} + ${output.setByOffset('global_idx', 'value')} + `; + } else { + assignment = ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + ${calcDataIndices('')}; + let value = ${data.getByIndices('dataIndices')}; + ${output.setByOffset('global_idx', 'value')}; + `; } - return calcStr; - }; - - const getShaderSource = (shaderHelper: ShaderHelper) => ` + return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('axisDimLimit', 'i32') - .registerUniform('axis', 'u32') - .declareVariables(data, indices, output)} + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(data, indices, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} - let outputIndices = ${output.offsetToIndices('global_idx')}; - ${calcDataIndices()}; - let value = ${data.getByIndices('dataIndices')}; - ${output.setByOffset('global_idx', 'value')}; + ${assignment} }`; + }; return { name: 'Gather', shaderCache: {hint: attributes.cacheKey, inputDependencies}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 6e9dee41ce488..1c5d28e4b8e3f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -97,8 +97,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let m = global_id.x / N; - let n = global_id.x % N; + let m = global_idx / N; + let n = global_idx % N; var value = ${dataType}(0); for (var k: u32 = 0u; k<${K}u; k++) { @@ -107,7 +107,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${calculateAlpha} ${calculateC} - output[global_id.x] = value; + output[global_idx] = value; }`; return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 97f633c7cf47e..3a84844544c96 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; export interface InstanceNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ 
-26,22 +26,25 @@ const createInstanceNormProgramInfo = const axis = 2; const normCount = ShapeUtil.sizeToDimension(xShape, axis); const normSize = ShapeUtil.sizeFromDimension(xShape, axis); + const components = getMaxComponents(normSize); + const normPackedSize = normSize / components; const C = xShape[1]; - const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normSize]); + const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normSize]); + const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); const variables = [x, scale, bias, output]; const dataType = x.type.value; + const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`; const workgroupSize = 64; const getShaderSource = (shaderHelper: ShaderHelper) => ` const C: u32 = ${C}; const normSize: u32 = ${normSize}; const epsilon: f32 = ${attributes.epsilon}; - var<workgroup> meanShared : ${dataType}; - var<workgroup> squaredNormShared : ${dataType}; - var<workgroup> workgroupShared : array<${dataType}, ${workgroupSize}>; + var<workgroup> meanShared : f32; + var<workgroup> squaredNormShared : f32; + var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>; const workgroupSize = ${workgroupSize}u; ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart(workgroupSize)} @@ -51,9 +54,9 @@ const createInstanceNormProgramInfo = let localIndex = local_id.x; // initialize workgroup memory - var initial: ${dataType} = 0; - for (var h = localIndex; h < normSize; h += workgroupSize) { - initial = initial + ${x.get('batch', 'channel', 'h')}; + var initial = ${f32Type}(0); + for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')}); } workgroupShared[localIndex] = initial; workgroupBarrier(); @@ -66,14 +69,14 @@ const createInstanceNormProgramInfo = workgroupBarrier(); } if (localIndex == 0) { - meanShared = workgroupShared[0] / ${dataType}(normSize); + meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize); } workgroupBarrier(); // reinitialize workgroup memory. 
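// Second pass mirrors the first: each invocation accumulates squared deviations from meanShared into its workgroupShared slot, and the halving tree reduction below sums the partials so that squaredNormShared ends up holding sum((x - mean)^2) for this (batch, channel) pair.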
- initial = 0; - for (var h = localIndex; h < normSize; h += workgroupSize) { - let deviation = ${x.get('batch', 'channel', 'h')} - meanShared; + initial = ${f32Type}(0); + for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared); initial = initial + deviation * deviation; } workgroupShared[localIndex] = initial; @@ -87,15 +90,16 @@ const createInstanceNormProgramInfo = workgroupBarrier(); } if (localIndex == 0) { - squaredNormShared = workgroupShared[0]; + squaredNormShared = ${sumVector('workgroupShared[0]', components)}; } workgroupBarrier(); - let invStdDev = 1 / sqrt(squaredNormShared / ${dataType}(normSize) + epsilon); - let channelScale = invStdDev * ${scale.getByOffset('channel')}; - let channelShift = ${bias.getByOffset('channel')} - meanShared * channelScale; - for (var h = localIndex; h < normSize; h += workgroupSize) { - let value = ${x.get('batch', 'channel', 'h')} * channelScale + channelShift; + let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon); + let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); + let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; + for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ + f32Type}(channelShift)); ${output.set('batch', 'channel', 'h', 'value')}; } }`; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 19ca4ac5358ae..de9309d1e436f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -2,10 +2,150 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {BroadcastUtil} from '../../util'; -import {ComputeContext} from '../types'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; +import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; + +export const createNaiveMatmulProgramInfo = + (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], + reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + + const M = aShape[aShape.length - 2]; + const N = bShape[bShape.length - 1]; + const K = aShape[aShape.length - 1]; + const components = getMaxComponents(N); + const aComponents = getMaxComponents(K); + const outputNumber = getMaxComponents(M); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const hasBias = inputs.length > 2; + const outerDims = reshapedOutputShape ? 
reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchSize = ShapeUtil.size(outerDims); + const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, + {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), + ...createTensorShapeVariables(bShape) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } + programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', inputs[1].dataType, bShape.length, components); + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const inputVariables = [a, b]; + let processBias = ''; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + processBias = `${ + isChannelsLast ? `value += bias[col / ${biasComponents}];` : + `value += ${output.type.value}(bias[row + i]);`}`; + } + + const outerDimsA = aShape.slice(0, -2); + const outerDimsB = bShape.slice(0, -2); + const broadCastADims = getBroadcastDims(outerDimsA, outerDims); + const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { + const rank = variable.rank; + const name = variable.name; + if (rank === 2) { + return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`; + } + const batchRank = batchDims.rank; + let resStr = `var ${name}_indices: ${variable.type.indices};`; + for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { + resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`; + } + broadCastDims.forEach(i => { + resStr += `\n${name}_indices[${i}] = 0;`; + }); + resStr += `${name}_indices[${rank - 2}] = 0u; + ${name}_indices[${rank - 1}] = 0u;`; + return resStr; + }; + + const calcResult = (): string => { + let calcStr = `var a_data: ${a.type.value};`; + for (let i = 0; i < aComponents; i++) { + calcStr += ` + let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`; + } + for (let i = 0; i < outputNumber; i++) { + calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; + + for (let j = 0; j < aComponents; j++) { + calcStr += ` + values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? 
'' : `[${j}]`}), b_data${j}, values[${ + i}]);\n`; + } + } + return calcStr; + }; + + return ` + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('M', 'u32') + .registerUniform('N', 'u32') + .registerUniform('K', 'u32') + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} + ${activationFunction} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + let col = (global_idx % (uniforms.N / ${components})) * ${components}; + var index1 = global_idx / (uniforms.N / ${components}); + let stride1 = uniforms.M / ${outputNumber}; + let row = (index1 % stride1) * ${outputNumber}; + let batch = index1 / stride1; + + ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`} + ${getIndices(a, broadCastADims)} + let a_offset = ${a.indicesToOffset('a_indices')}; + ${getIndices(b, broadCastBDims)} + let b_offset = ${b.indicesToOffset('b_indices')}; + var values: array<${output.type.value}, ${outputNumber}>; + for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) { + ${calcResult()} + } + for (var i = 0u; i < ${outputNumber}u; i++) { + var value = values[i]; + ${processBias} + ${applyActivation} + let cur_indices = ${output.type.indices}(batch, row + i, col); + let offset = ${output.indicesToOffset('cur_indices')}; + ${output.setByOffset(`offset / ${components}`, 'value')}; + } + } + `; + }; + return { + name: 'MatMulNaive', + shaderCache: { + hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ + isChannelsLast}`, + inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource + }; + }; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -23,5 +163,12 @@ export const matMul = (context: ComputeContext): void => { if (!outputShape) { throw new Error('Can\'t use matmul on the given tensors'); } - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + const N = outputShape[outputShape.length - 1]; + const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; + if (N < 8 && K < 8) { + context.compute( + createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + } else { + context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts new file mode 100644 index 0000000000000..b7726a36bcaad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
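+// Structure of this kernel, in brief: validateInputs derives AttentionParameters from the query/key/value shapes (plus optional bias, key padding mask, relative position bias and past key/value), maybeTransposeToBNSHAndAddBias brings Q, K and V into BNSH layout (adding bias where provided), and applyAttention from './attention' performs the attention computation itself.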
+ +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; + +const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + const query = inputs[0]; + const key = inputs[1]; + const value = inputs[2]; + const bias = inputs[3]; + const keyPaddingMask = inputs[4]; + const relativePositionBias = inputs[5]; + const pastKey = inputs[6]; + const pastValue = inputs[7]; + + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // past_key : (B, N, S*, H) + // past_value : (B, N, S*, H) + // When no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) or (B, N, S*, H) + // value (V) : (B, L, D_v) or (B, N, S*, H) + // bias (Q/K/V) : (D + D + D_v) + // When packed kv is used: + // query (Q) : (B, S, D) + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None or (D + D + D_v) + + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input query is expected to have 3 or 5 dimensions'); + } + + const dmmhaPacking = false; + const batchSize = query.dims[0]; + const sequenceLength = query.dims[1]; + const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? 
query.dims[2] / 3 : query.dims[2]) : + attributes.numHeads * query.dims[4]; + let kvSequenceLength = sequenceLength; + + let pastSequenceLength = 0; + let maxSequenceLength = 0; + const headSize = Math.floor(hiddenSize / attributes.numHeads); + if (pastKey && pastValue) { + if (pastKey.dims.length !== 4) { + throw new Error('Input "past_key" is expected to have 4 dimensions'); + } + if (pastValue.dims.length !== 4) { + throw new Error('Input "past_value" is expected to have 4 dimensions'); + } + pastSequenceLength = pastKey.dims[2]; + maxSequenceLength = pastKey.dims[2]; + } else if (pastKey || pastValue) { + throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); + } + + let qkvFormat: AttentionQkvFormat; + if (key) { + if (query.dims.length !== 3) { + throw new Error('Input "query" is expected to have 3 dimensions when key is given'); + } + if (key.dims.length < 3 || key.dims.length > 5) { + throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions'); + } + if (query.dims[0] !== key.dims[0]) { + throw new Error('Input "query" and "key" shall have same dim 0 (batch size)'); + } + + if (key.dims.length === 3) { + if (key.dims[2] !== query.dims[2]) { + throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)'); + } + qkvFormat = AttentionQkvFormat.qkvBSNH; + kvSequenceLength = key.dims[1]; + } else if (key.dims.length === 5) { + if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { + throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); + } + if (value) { + throw new Error('Expect "value" be none when "key" has packed kv format.'); + } + qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; + kvSequenceLength = key.dims[1]; + } else { // key_dims.size() == 4 (cross-attention with past_key) + if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { + throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); + } + + qkvFormat = AttentionQkvFormat.unknown; + kvSequenceLength = key.dims[2]; + } + } else { // packed QKV + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + } + if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); + } + + qkvFormat = AttentionQkvFormat.qkvBSN3H; + } + + if (bias) { + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimension'); + } + + if (value) { + if (query.dims.length === 5 && query.dims[3] === 2) { + throw new Error('bias is not allowed for packed kv.'); + } + } + } + + let maskType: AttentionMaskType = AttentionMaskType.none; + if (keyPaddingMask) { + maskType = AttentionMaskType.maskUnknown; + const maskDims = keyPaddingMask.dims; + if (maskDims.length === 1) { + if (maskDims[0] === batchSize) { + maskType = AttentionMaskType.mask1dKeySeqLen; + } else if (maskDims[0] === 3 * batchSize + 2) { + maskType = AttentionMaskType.mask1DKeySeqLenStart; + } + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + maskType = AttentionMaskType.mask2dKeyPadding; + } + if (maskType === AttentionMaskType.maskUnknown) { + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, 
kv_sequence_length)'); } + throw new Error('Mask not supported'); + } + + let passPastInKv = false; + let vHiddenSize = hiddenSize; + if (value) { + if (value.dims.length !== 3 && value.dims.length !== 4) { + throw new Error('Input "value" is expected to have 3 or 4 dimensions'); + } + + if (query.dims[0] !== value.dims[0]) { + throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)'); + } + + if (value.dims.length === 3) { + if (kvSequenceLength !== value.dims[1]) { + throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); + } + vHiddenSize = value.dims[2]; + } else { + if (kvSequenceLength !== value.dims[2]) { + throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + } + vHiddenSize = value.dims[1] * value.dims[3]; + passPastInKv = true; + } + } + + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const broadcastResPosBias = false; + // if (extraAddQk) { + // if (extraAddQk.dims[0] === 1) { + // broadcastResPosBias = true; + // } + // } + + if (keyPaddingMask) { + throw new Error('Key padding mask is not supported'); + } + if (relativePositionBias) { + throw new Error('extraAddQk is not supported'); + } + if (pastKey) { + throw new Error('pastKey is not supported'); + } + if (pastValue) { + throw new Error('pastValue is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize: 0, + hiddenSize, + vHiddenSize, + headSize, + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias, + passPastInKv, + qkvFormat, + }; +}; + + +export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); + +const addBiasTranspose = + (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, + hiddenSize: number, biasOffset: number) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); + + const dataType = tensorTypeToWsglStorageType(qkv.dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const biasOffset = ${biasOffset}u; + const hiddenSize = ${hiddenSize}u; + + @group(0) @binding(0) var<storage, read> qkv: array<${dataType}>; + @group(0) @binding(1) var<storage, read> bias: array<${dataType}>; + @group(0) @binding(2) var<storage, read_write> qkv_with_bias: array<${dataType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + + qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + }`; + + return context.compute( + { + name: 'MultiHeadAttentionAddBias', + shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [qkv, bias], outputs: [-1]})[0]; + }; + +const maybeTransposeToBNSHAndAddBias = + (context: ComputeContext, batchSize: 
number, numHeads: number, sequenceLength: number, headSize: number, + input: TensorView, bias?: TensorView, biasOffset?: number) => { + // const newDims = []; + + let reshapedInput = input; + if (!bias) { + if (input.dims.length === 3) { + reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); + } + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } else { + if (sequenceLength === 1) { + throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV'); + } else { + reshapedInput = + addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!); + reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } + } + }; + +export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateInputs(context.inputs, attributes); + + if (context.inputs[0].dims.length === 5) { + throw new Error('Packed QKV is not implemented'); + } + + if (context.inputs[1]?.dims.length === 5) { + throw new Error('Packed KV is not implemented'); + } + + // applyAttention expects BNSH inputs + const kvBNSH = context.inputs[1] && context.inputs[2] && context.inputs[1].dims.length === 4 && + context.inputs[2].dims.length === 4; + + const Q = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], + context.inputs[3], 0); + + if (kvBNSH) { + return applyAttention( + context, Q, context.inputs[1], context.inputs[2], context.inputs[4], undefined, undefined, undefined, + context.inputs[5], params, attributes); + } + + const K = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, context.inputs[1], + context.inputs[3], params.hiddenSize); + + const V = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, context.inputs[2], + context.inputs[3], 2 * params.hiddenSize); + + applyAttention( + context, Q, K, V, context.inputs[4], undefined, context.inputs[6], context.inputs[7], context.inputs[5], params, + attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 1538644412afd..84d04efc37f28 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -1,12 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
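+// Summary of the change below: pooling kernel size, strides and pads move from values baked into the generated shader source into program uniforms (see getUniformAndPadInfo), and input dims are read from uniforms.x_shape, so a compiled pooling shader can be reused across input shapes instead of being recompiled per shape.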
+import {env} from 'onnxruntime-common'; + import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; // TODO: support: // - ceil_mode "test_maxpool_2d_ceil" @@ -15,12 +17,9 @@ import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './comm // - [MaxPool] output[1] "test_maxpool_with_argmax_2d_precomputed_pads" const validateInputs = (inputs: readonly TensorView[]): void => { - if (!inputs || inputs.length !== 1) { + if (env.webgpu.validateInputContent && (!inputs || inputs.length !== 1)) { throw new Error('Pool ops requires 1 input.'); } - if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { - throw new Error('Pool ops supports 1-D or 2-D inputs only for now.'); - } }; const getAdjustedPoolAttributesAndOutputShape = ( @@ -51,30 +50,83 @@ const getAdjustedPoolAttributesAndOutputShape = -const generatePoolingCode = <AttributeType extends AttributesBase>( - shaderHelper: ShaderHelper, x: IndicesHelper, xShape: readonly number[], outputShape: readonly number[], - attributes: AttributeType, op1: string, op2: string, start: string): string => { +const getUniformAndPadInfo = <AttributeType extends AttributesBase>( + outputShape: readonly number[], + attributes: AttributeType): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => { const isChannelsLast = attributes.format === 'NHWC'; - const inputDims = xShape; - const dataType = x.type.value; - const rank = inputDims.length; const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', x.type.tensor, outputShape); - + const kernelSize = ShapeUtil.size(attributes.kernelShape); + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}, {type: 'uint32', data: kernelSize}]; + const uniforms: UniformsArrayType = [{name: 'outputSize', type: 'u32'}, {name: 'kernelSize', type: 'u32'}]; if (attributes.kernelShape.length <= 2) { const kw = attributes.kernelShape[attributes.kernelShape.length - 1]; const sw = attributes.strides[attributes.strides.length - 1]; const pwStart = attributes.pads[attributes.pads.length / 2 - 1]; const pwEnd = attributes.pads[attributes.pads.length - 1]; - const dimIdxW = rank - (isChannelsLast ? 
2 : 1); + const pwStartEnd = !!(pwStart + pwEnd); + programUniforms.push( + {type: 'uint32', data: kw}, + {type: 'uint32', data: sw}, + {type: 'uint32', data: pwStart}, + {type: 'uint32', data: pwEnd}, + ); + uniforms.push( + {name: 'kw', type: 'u32'}, {name: 'sw', type: 'u32'}, {name: 'pwStart', type: 'u32'}, + {name: 'pwEnd', type: 'u32'}); + + let phStartEnd = false; + if (attributes.kernelShape.length === 2) { + const kh = attributes.kernelShape[attributes.kernelShape.length - 2]; + const sh = attributes.strides[attributes.strides.length - 2]; + const phStart = attributes.pads[attributes.pads.length / 2 - 2]; + const phEnd = attributes.pads[attributes.pads.length - 2]; + phStartEnd = !!(phStart + phEnd); + programUniforms.push( + {type: 'uint32', data: kh}, {type: 'uint32', data: sh}, {type: 'uint32', data: phStart}, + {type: 'uint32', data: phEnd}); + + uniforms.push( + {name: 'kh', type: 'u32'}, {name: 'sh', type: 'u32'}, {name: 'phStart', type: 'u32'}, + {name: 'phEnd', type: 'u32'}); + } + return [programUniforms, uniforms, true, pwStartEnd, phStartEnd]; + } else { + if (isChannelsLast) { + throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.'); + } + const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape); + programUniforms.push( + {type: 'uint32', data: kernelStrides}, {type: 'uint32', data: attributes.pads}, + {type: 'uint32', data: attributes.strides}); + uniforms.push( + {name: 'kernelStrides', type: 'u32', length: kernelStrides.length}, + {name: 'pads', type: 'u32', length: attributes.pads.length}, + {name: 'strides', type: 'u32', length: attributes.strides.length}); + + const hasPads = attributes.pads.reduce((sum, cur) => sum + cur); + return [programUniforms, uniforms, !!hasPads, false, false]; + } +}; + +const generatePoolingCode = <AttributeType extends AttributesBase>( + shaderHelper: ShaderHelper, x: IndicesHelper, rank: number, outputShapeRank: number, attributes: AttributeType, + op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEnd: boolean, + phStartEnd: boolean): string => { + const isChannelsLast = attributes.format === 'NHWC'; + const dataType = x.type.value; + const output = outputVariable('output', x.type.tensor, outputShapeRank); + + if (attributes.kernelShape.length <= 2) { let codeW = ''; let codeH = ''; let codeHEnd = ''; - if (pwStart + pwEnd !== 0) { + const dimIdxW = rank - (isChannelsLast ? 
2 : 1); + if (pwStartEnd === true) { codeW = ` - for (var i: u32 = 0u; i < ${kw}u; i++) { - xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; - if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) { + for (var i: u32 = 0u; i < uniforms.kw; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i; + if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] + >= uniforms.x_shape[${dimIdxW}]) { pad++; continue; } @@ -83,33 +135,28 @@ const generatePoolingCode = <AttributeType extends AttributesBase>( - if (phStart + phEnd !== 0) { + if (phStartEnd === true) { codeH = ` - for (var j: u32 = 0u; j < ${kh}u; j++) { - xIndices[${dimIdxH}] = indices[${dimIdxH}] * ${sh} - ${phStart} + j; - if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= ${dimH}) { - pad+= ${kw}; + for (var j: u32 = 0u; j < uniforms.kh; j++) { + xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j; + if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= uniforms.x_shape[${dimIdxH}]) { + pad += i32(uniforms.kw); continue; } `; } else { codeH = ` - for (var j: u32 = 0u; j < ${kh}u; j++) { - xIndices[${dimIdxH}] = indices[${dimIdxH}] * ${sh} - ${phStart} + j; + for (var j: u32 = 0u; j < uniforms.kh; j++) { + xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j; `; } codeHEnd = ` @@ -118,15 +165,15 @@ const generatePoolingCode = <AttributeType extends AttributesBase>( } else { if (isChannelsLast) { throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.'); } - const kernelSize = ShapeUtil.size(attributes.kernelShape); - const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape); - const stridesRank = kernelStrides.length; + const stridesRank = attributes.kernelShape.length; const padsRank = attributes.pads.length; - const hasPads = attributes.pads.reduce((sum, cur) => sum + cur); let padCode = ''; if (hasPads) { padCode = ` - if (xIndices[j] >= inputDims[j]) { + if (xIndices[j] >= uniforms.x_shape[j]) { pad++; isPad = true; break; @@ -166,37 +210,32 @@ const generatePoolingCode = <AttributeType extends AttributesBase>( - const pads = array<u32, ${padsRank}>(${attributes.pads.map(i => `${i}u`).join(',')}); - const inputDims = array<u32, ${rank}>(${inputDims.map(i => `${i}u`).join(',')}); - const kernelStrides = array<u32, ${stridesRank}>(${kernelStrides.map(i => `${i}u`).join(',')}); - const strides = array<u32, ${stridesRank}>(${attributes.strides.map(i => `${i}u`).join(',')}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let indices = ${output.offsetToIndices('global_idx')}; - let xIndices = ${output.offsetToIndices('global_idx')}; + var xIndices = ${output.offsetToIndices('global_idx')}; var offsets: array<u32, ${stridesRank}>; - var value = ${output.type.value}(${start}); + var value = ${dataType}(${start}); var pad = 0; var isPad = false; - for (var i: u32 = 0u; i < ${kernelSize}u; i++) { + for (var i: u32 = 0u; i < uniforms.kernelSize; i++) { var offset = i; for (var j = 0u; j < ${stridesRank - 1}u; j++) { - offsets[j] = offset / kernelStrides[j]; - offset -= offsets[j] * kernelStrides[j]; + offsets[j] = offset / ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)}; + offset -= offsets[j] * ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)}; } offsets[${stridesRank - 1}] = offset; isPad = false; for (var j = ${rank - stridesRank}u; j < ${rank}u; j++) { - xIndices[j] = indices[j] * strides[j - ${rank - stridesRank}u] - + offsets[j - ${rank - stridesRank}u] - pads[j - 2u]; + xIndices[j] = indices[j] * ${ + getElementAt('uniforms.strides', `j - ${rank - stridesRank}u`, stridesRank)} + + offsets[j - ${rank - stridesRank}u] - ${getElementAt('uniforms.pads', 'j - 2u', padsRank)}; ${padCode} } ${op2} @@ -236,27 +275,35 @@ const createAveragePoolProgramInfo = (name: string, input: TensorView, 
isGlobalOperator: boolean, attributes: AveragePoolAttributes): ProgramInfo => { const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator); - const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); - - const x = inputVariable('x', input.dataType, input.dims); + const x = inputVariable('x', input.dataType, input.dims.length); const dataType = x.type.value; const op1 = 'value += x_val;'; let op2 = ''; if (adjustedAttributes.countIncludePad) { - op2 += `value /= ${dataType}(${kernelSize});`; + op2 += `value /= ${dataType}(uniforms.kernelSize);`; } else { - op2 += `value /= ${dataType}(${kernelSize} - pad);`; + op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`; } + const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] = + getUniformAndPadInfo(outputShape, adjustedAttributes); + programUniforms.push(...createTensorShapeVariables(input.dims)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; return { name, - shaderCache: {hint: attributes.cacheKey}, + shaderCache: { + hint: attributes.cacheKey + hasPads + pwStartEnd + phStartEnd + adjustedAttributes.countIncludePad, + inputDependencies + }, getRunData: () => ({ outputs: [{dims: outputShape, dataType: input.dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}, + programUniforms }), - getShaderSource: shaderHelper => - generatePoolingCode(shaderHelper, x, input.dims, outputShape, adjustedAttributes, op1, op2, '0.0'), + getShaderSource: shaderHelper => generatePoolingCode( + shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, 0.0, uniforms, + hasPads, pwStartEnd, phStartEnd), }; }; @@ -312,16 +359,23 @@ const createMaxPoolProgramInfo = value = max(x_val, value); `; const op2 = ''; - const x = inputVariable('x', input.dataType, input.dims); + const x = inputVariable('x', input.dataType, input.dims.length); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; + const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] = + getUniformAndPadInfo(outputShape, adjustedAttributes); + programUniforms.push(...createTensorShapeVariables(input.dims)); + programUniforms.push(...createTensorShapeVariables(outputShape)); return { name, - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey + hasPads, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: input.dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}, + programUniforms }), - getShaderSource: shaderHelper => - generatePoolingCode(shaderHelper, x, input.dims, outputShape, adjustedAttributes, op1, op2, '-1e5'), + getShaderSource: shaderHelper => generatePoolingCode( + shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms, + hasPads, pwStartEnd, phStartEnd), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index 1365d1e9a12a4..7c440cbffea7b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -141,7 +141,6 @@ export const createReduceSharedProgramInfo = 
return ((a - 1u) / b + 1u); } ${shaderHelper.mainStart(workgroupSize)} - let local_idx = local_id.x; let outputIndex = global_idx / ${workgroupSize}; let offset = outputIndex * uniforms.reduceSize; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index b5c956e57a9b1..e8851ac546942 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared'; const validateInputs = (inputs: readonly TensorView[]): void => { @@ -30,14 +30,14 @@ export type ReduceOp = (input: IndicesHelper, output: IndicesHelper, axes: readonly number[]) => [string, string, string, string, ...string[]]; -const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByOffset('inputOffset')};`, '']; +const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByIndices('input_indices')};`, '']; export const createReduceProgramInfo = (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceOp: ReduceOp, axesInput: number[], outputDataType: DataType, keepDims = false, noopWithEmptyAxes = false): ProgramInfo => { const outputShape: number[] = []; const inputShape = inputs[0].dims; - - const axes = ShapeUtil.normalizeAxes(axesInput, inputs[0].dims.length); + const inputRank = inputShape.length; + const axes = ShapeUtil.normalizeAxes(axesInput, inputRank); const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0; inputShape.forEach((d, i) => { if (reduceOnAllAxes || axes.indexOf(i) >= 0) { @@ -48,53 +48,50 @@ export const createReduceProgramInfo = outputShape.push(d); } }); - - const idxCopy: string[] = []; // copy output indexes to input indexes - - const input = inputVariable('_A', inputs[0].dataType, inputShape); - const output = outputVariable('output', outputDataType, outputShape); - const ops = reduceOp(input, output, axes); - const inputOffsetAssignment = `inputOffset = ${input.indicesToOffset('inputIndices')};`; - const initinputOffsetLet = `let ${inputOffsetAssignment};`; - const initinputOffsetVar = `var ${inputOffsetAssignment};`; - const initinputOffset = (ops[1] === '') ? '' : initinputOffsetVar; - let reduceOps = ((ops[1] === '') ? 
initinputOffsetLet : inputOffsetAssignment) + '\n' + ops[2]; - - for (let k = 0, l = 0; k < inputs[0].dims.length; k++) { - // if this axis is reduced - if (reduceOnAllAxes || axes.indexOf(k) >= 0) { - if (keepDims) { + const outputRank = outputShape.length; + const outputSize = ShapeUtil.size(outputShape); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const idxCopy: string[] = []; // copy output indexes to input indexes + + const input = inputVariable('_A', inputs[0].dataType, inputRank); + const output = outputVariable('output', outputDataType, outputRank); + const ops = reduceOp(input, output, axes); + let reduceOps = ops[2]; + + for (let k = 0, l = 0; k < inputRank; k++) { + // if this axis is reduced + if (reduceOnAllAxes || axes.indexOf(k) >= 0) { + if (keepDims) { + l++; + } + // loop over the d-th axis + reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputShape[k]}; j${k}++) { + ${ops[2].includes('last_index') ? `let last_index = j${k};` : ''} + ${input.indicesSet('input_indices', k, `j${k}`)} + ${reduceOps} + }`; + } else { + idxCopy.push(`${input.indicesSet('input_indices', k, output.indicesGet('output_indices', l))};`); l++; } - // loop over the d-th axis - reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) { - ${ops[2].includes('lastIndex') ? `let lastIndex = j${k};` : ''} - ${input.indicesSet('inputIndices', k, `j${k}`)} - ${reduceOps} - }`; - } else { - idxCopy.push(`${input.indicesSet('inputIndices', k, output.indicesGet('outputIndices', l))};`); - l++; } - } + return ` - const outputSize = ShapeUtil.size(outputShape); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - var inputIndices: ${input.type.indices}; - let outputIndices = ${output.offsetToIndices('global_idx')}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var input_indices: ${input.type.indices}; + let output_indices = ${output.offsetToIndices('global_idx')}; ${idxCopy.join('\n')} ${ops[0]} // init ops for reduce max/min - ${initinputOffset} ${ops[1]} ${reduceOps} ${ops[3]} ${ops.length === 4 ? output.setByOffset('global_idx', 'value') : ops.slice(4).join('\n')} }`; + }; return { name, @@ -102,7 +99,11 @@ export const createReduceProgramInfo = getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ + {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + ...createTensorShapeVariables(outputShape) + ] }), }; }; @@ -125,7 +126,7 @@ const runReduceProgram = context.compute( createReduceProgramInfo( - name, {hint: updatedAttributes.cacheKey}, [inputs[0]], + name, {hint: updatedAttributes.cacheKey, inputDependencies: ['rank']}, [inputs[0]], updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? 
noOp : reduceOp, updatedAttributes.axes, inputs[0].dataType, updatedAttributes.keepDims, updatedAttributes.noopWithEmptyAxes), @@ -137,7 +138,7 @@ const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += ${input.getByOffset('inputOffset')};`, + `value += ${input.getByIndices('input_indices')};`, 'value = log(value);', ]; runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp); @@ -148,7 +149,7 @@ const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): v const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += abs(${input.getByOffset('inputOffset')});`, + `value += abs(${input.getByIndices('input_indices')});`, '', ]; runReduceProgram(context, 'ReduceL1', attributes, reduceOp); @@ -159,7 +160,7 @@ const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): v const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', - `t = ${input.getByOffset('inputOffset')}; value += (t * t);`, + `t = ${input.getByIndices('input_indices')}; value += (t * t);`, 'value = sqrt(value);', ]; runReduceProgram(context, 'ReduceL2', attributes, reduceOp); @@ -170,7 +171,7 @@ const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttribu const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += exp(${input.getByOffset('inputOffset')});`, + `value += exp(${input.getByIndices('input_indices')});`, 'value = log(value);', ]; runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp); @@ -182,14 +183,14 @@ const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(input.indicesSet('inputIndices', k, 0)); + idxZero.push(input.indicesSet('input_indices', k, 0)); } } return [ `${idxZero.join('\n')}`, - `var value = ${input.getByOffset('inputOffset')};`, - `value = max(value, ${input.getByOffset('inputOffset')});`, + `var value = ${input.getByIndices('input_indices')};`, + `value = max(value, ${input.getByIndices('input_indices')});`, '', ]; }; @@ -210,7 +211,7 @@ const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): return [ 'var sum = f32(0);', '', - `sum += f32(${input.getByOffset('inputOffset')});`, + `sum += f32(${input.getByIndices('input_indices')});`, `let value = ${output.type.value}(sum / ${size});`, ]; }; @@ -223,14 +224,14 @@ const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIndices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ `${idxZero.join('\n')}`, - `var value = ${input.getByOffset('inputOffset')};`, - `value = min(value, ${input.getByOffset('inputOffset')});`, + `var value = ${input.getByIndices('input_indices')};`, + `value = min(value, ${input.getByIndices('input_indices')});`, '', ]; }; @@ -242,7 +243,7 @@ const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(1);`, '', - `value *= ${input.getByOffset('inputOffset')};`, + `value *= 
${input.getByIndices('input_indices')};`, '', ]; runReduceProgram(context, 'ReduceProd', attributes, reduceOp); @@ -253,7 +254,7 @@ const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, '', - `value += ${input.getByOffset('inputOffset')};`, + `value += ${input.getByIndices('input_indices')};`, '', ]; runReduceProgram(context, 'ReduceSum', attributes, reduceOp); @@ -264,7 +265,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', - `t = ${input.getByOffset('inputOffset')}; value += t * t;`, + `t = ${input.getByIndices('input_indices')}; value += t * t;`, '', ]; runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp); @@ -273,7 +274,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu const useNaiveReduceMethod = (shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => { if (axes.length === 0) { - return noopWithEmptyAxes ? true : false; + return noopWithEmptyAxes; } let outputSize = 1; @@ -289,7 +290,7 @@ const useNaiveReduceMethod = // The condition data is very rough, although considering the count of Execution Unit (EU), the potential // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments // on some machines. - return reduceSize < 32 && outputSize > 1024 ? true : false; + return reduceSize < 32 && outputSize > 1024; }; export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { @@ -371,6 +372,3 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut reduceLogSumShared(context, attributes); } }; - -export const parseReduceAttributes = (attributes: Record<string, unknown>): ReduceAttributes => - createAttributeWithCacheKey(attributes as Omit<ReduceAttributes, keyof AttributeWithCacheKey>); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 9869561a36251..e1369c2c2b43b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; type CoordinateTransformMode = 'half_pixel'|'asymmetric'|'pytorch_half_pixel'|'tf_half_pixel_for_nn'|'align_corners'| 'tf_crop_and_resize'|'half_pixel_symmetric'; @@ -105,50 +105,51 @@ const validateInputs = } }; -const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string => - 'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\ - lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' + +const getOriginalCoordinateFromResizedCoordinate = + (coordinateTransferMode: CoordinateTransformMode, dType: string): string => + `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType}, + lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + (() => { - switch (coordinateTransferMode) { - case 'asymmetric': - 
return 'return xResized / xScale;'; - case 'pytorch_half_pixel': - return 'if (lengthResized > 1) { \ + switch (coordinateTransferMode) { + case 'asymmetric': + return 'return xResized / xScale;'; + case 'pytorch_half_pixel': + return 'if (lengthResized > 1) { \ return (xResized + 0.5) / xScale - 0.5; \ } else { \ return 0.0; \ }'; - case 'tf_half_pixel_for_nn': - return 'return (xResized + 0.5) / xScale;'; - case 'align_corners': - return 'if (lengthResized == 1) { \ + case 'tf_half_pixel_for_nn': + return 'return (xResized + 0.5) / xScale;'; + case 'align_corners': + return 'if (lengthResized == 1) { \ return 0.0; \ } else { \ return xResized * (lengthOriginal - 1) / (lengthResized - 1); \ }'; - case 'tf_crop_and_resize': - return 'if (lengthResized > 1) { \ + case 'tf_crop_and_resize': + return `if (lengthResized > 1) { \ return roiStart * (lengthOriginal - 1) + \ (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \ } else { \ - return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); \ - }'; - case 'half_pixel_symmetric': - return [ - 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', - 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', - 'return offset + ((xResized + 0.5) / xScale) - 0.5;' - ].join('\n'); - case 'half_pixel': - return 'return ((xResized + 0.5) / xScale) - 0.5;'; - default: - throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); - } - })() + + return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \ + }`; + case 'half_pixel_symmetric': + return [ + 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', + 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', + 'return offset + ((xResized + 0.5) / xScale) - 0.5;' + ].join('\n'); + case 'half_pixel': + return 'return ((xResized + 0.5) / xScale) - 0.5;'; + default: + throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); + } + })() + '}'; -const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string => - 'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => { +const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string => + `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => { switch (nearestMode) { case 'round_prefer_ceil': return 'if (fract(xOriginal) == 0.5) { \ @@ -244,67 +245,67 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr }; const calculateOriginalIndicesFromOutputIndices = - (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], - roi: readonly number[]): string => ` - fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array { - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); - var originalIndices: array; + (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scalesLength: number, + roiLength: number): string => ` + fn calculateOriginalIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> array<${ + 
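The coordinate-transform modes above are easier to sanity-check outside WGSL. A plain TypeScript reference for two of them, with worked values (helper names are mine, for illustration; the formulas are the ones generated above):

```ts
// half_pixel: the default ONNX Resize mapping from output to input coordinates.
const halfPixel = (xResized: number, xScale: number): number =>
    (xResized + 0.5) / xScale - 0.5;

// align_corners: the endpoints of the output grid land exactly on the input's.
const alignCorners = (xResized: number, lengthResized: number, lengthOriginal: number): number =>
    lengthResized === 1 ? 0 : xResized * (lengthOriginal - 1) / (lengthResized - 1);

// Upscaling a length-2 axis to length 4 (scale = 2):
console.log(halfPixel(1, 2));        // 0.25 — falls between input pixels 0 and 1
console.log(alignCorners(3, 4, 2));  // 1    — last output maps onto the last input
```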
output.type.value}, ${outputShape.length}> { + var original_indices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - if (scales[i] == 1.0) { - originalIndices[i] = f32(outputIndex); + var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; + var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; + var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; + if (scale == 1.0) { + original_indices[i] = output_index; } else { - originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); + var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); + var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, + input_shape_i, roi_low, roi_hi); } } - return originalIndices; + return original_indices; }`; const calculateInputIndicesFromOutputIndices = (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - scales: readonly number[], roi: readonly number[], useExtrapolation: boolean): string => ` - fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); - var inputIndices: ${input.type.indices}; - for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var outputIndex = ${outputShape.length === 1 ? 
'outputIndices' : 'outputIndices[i]'}; - var inputIndex: u32; - if (scales[i] == 1.0) { - inputIndex = outputIndex; - } else { - var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) { - if (original_idx < 0) { - inputIndex = 0; - } else if (original_idx > (f32(inputShape[i]) - 1)) { - inputIndex = inputShape[i] - 1; - } else { - inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1)); - } + scalesLength: number, roiLength: number, useExtrapolation: boolean): string => ` + fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { + var input_indices: ${input.type.indices}; + for (var i:u32 = 0; i < ${outputShape.length}; i++) { + var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var input_index: u32; + var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; + if (scale == 1.0) { + input_index = u32(output_index); + } else { + var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; + var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; + var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); + var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, + input_shape_i, roi_low, roi_hi); + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) { + if (original_idx < 0) { + input_index = 0; + } else if (original_idx > (input_shape_i - 1)) { + input_index = u32(input_shape_i) - 1; } else { - inputIndex = u32(original_idx); + input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1)); } + } else { + input_index = u32(original_idx); } - ${input.indicesSet('inputIndices', 'i', 'inputIndex')} } - return inputIndices; + ${input.indicesSet('input_indices', 'i', ' input_index')} + } + return input_indices; }`; - const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): string => ` - fn checkInputIndices(inputIndices: ${input.type.indices}) -> bool { - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + fn checkInputIndices(input_indices: ${input.type.indices}) -> bool { for (var i:u32 = 0; i < ${inputShape.length}; i++) { - var inputIndex = ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'}; - if (inputIndex < 0 || inputIndex >= inputShape[i]) { + var input_index = ${input.indicesGet('input_indices', 'i')}; + if (input_index < 0 || input_index >= ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}) { return false; } } @@ -316,22 +317,23 @@ const bilinearInterpolation = useExtrapolation: boolean, extrapolationValue: number): string => { const [batchIdx, heightIdx, widthIdx, channelIdx] = inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? 
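The hunks above and below index uniform arrays through `getElementAt` instead of plain subscripting. A hypothetical simplification of what that helper emits (the real implementation lives in common.ts; the vec4-packing rule here is my assumption, based on WGSL's requirement that uniform arrays be 16-byte aligned):

```ts
// Uniform arrays longer than four elements are assumed to be declared as
// array<vec4<T>, N> in WGSL, so a flat index splits into vector + component.
const getElementAtSketch = (name: string, index: string, length: number): string => {
  if (length > 4) {
    return `${name}[(${index}) / 4][(${index}) % 4]`;
  }
  return length > 1 ? `${name}[${index}]` : name;
};

console.log(getElementAtSketch('uniforms.scales', 'i', 8));  // uniforms.scales[(i) / 4][(i) % 4]
console.log(getElementAtSketch('uniforms.roi', 'i', 2));     // uniforms.roi[i]
```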
[0, 2, 3, 1] : [0, 1, 2, 3]); + const dType = input.type.value; return ` - fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 { - var inputIndices: ${input.type.indices}; - inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1)); - inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1)); + fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { + var input_indices: ${input.type.indices}; + ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)}; + ${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)}; if (${inputShape.length} > 2) { - inputIndices[${channelIdx}] = channel; - inputIndices[${batchIdx}] = batch; + ${input.indicesSet('input_indices', channelIdx, 'channel')}; + ${input.indicesSet('input_indices', batchIdx, 'batch')}; }; - return input[${input.indicesToOffset('inputIndices')}]; + return ${input.getByIndices('input_indices')}; } - fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 { - var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices); - var row:f32 = originalIndices[${heightIdx}]; - var col:f32 = originalIndices[${widthIdx}]; + fn bilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} { + var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices); + var row:${dType} = originalIndices[${heightIdx}]; + var col:${dType} = originalIndices[${widthIdx}]; if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${ inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; @@ -348,14 +350,14 @@ const bilinearInterpolation = channel = u32(originalIndices[${channelIdx}]); batch = u32(originalIndices[${batchIdx}]); } - var x11: f32 = getInputValue(batch, channel, row1, col1); - var x12: f32 = getInputValue(batch, channel, row1, col2); - var x21: f32 = getInputValue(batch, channel, row2, col1); - var x22: f32 = getInputValue(batch, channel, row2, col2); - var dx1: f32 = row - f32(row1); - var dx2: f32 = f32(row2 ) - row; - var dy1 = col - f32(col1); - var dy2 = f32(col2) - col; + var x11: ${dType} = getInputValue(batch, channel, row1, col1); + var x12: ${dType} = getInputValue(batch, channel, row1, col2); + var x21: ${dType} = getInputValue(batch, channel, row2, col1); + var x22: ${dType} = getInputValue(batch, channel, row2, col2); + var dx1: ${dType} = row - ${dType}(row1); + var dx2: ${dType} = ${dType}(row2) - row; + var dy1 = col - ${dType}(col1); + var dy2 = ${dType}(col2) - col; return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); }`; }; @@ -365,24 +367,24 @@ const bicubicInterpolation = scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => { const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2]; - + const dType = input.type.value; const createCubicInterpolationFunction = (idx: number): string => { const direction = idx === heightIdx ? 'row' : 'col'; return ` - fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${ - output.type.indices}) -> f32 { - var outputIndex = ${outputShape.length === 1 ? 
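The bilinear blend above is the classic four-neighbour weighted sum. A plain TypeScript restatement with a worked value (illustrative only; the shader's edge clamping via getInputValue is omitted here):

```ts
// (row, col) is the fractional source coordinate; x11..x22 are the four
// neighbouring input values, laid out as x[row][col].
const bilinear =
    (x11: number, x12: number, x21: number, x22: number, row: number, col: number): number => {
      const row1 = Math.floor(row), row2 = row1 + 1;
      const col1 = Math.floor(col), col2 = col1 + 1;
      const dx1 = row - row1, dx2 = row2 - row;
      const dy1 = col - col1, dy2 = col2 - col;
      return x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1;
    };

// The centre of a unit cell averages its four corners:
console.log(bilinear(0, 1, 1, 2, 0.5, 0.5));  // 1
```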
'outputIndices' : `outputIndices[${idx}]`}; - var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]}, - f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); - var fractOriginalIdx: f32 = originalIdx - floor(originalIdx); + fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${ + output.type.indices}) -> ${dType} { + var output_index = ${output.indicesGet('output_indices', idx)}; + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]}, + ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); + var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx); if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) { return ${extrapolationValue}; } - var data: array = array(0.0, 0.0, 0.0, 0.0); + var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); for (var i: i32 = -1; i < 3; i++) { - var ${direction}: f32 = originalIdx + f32(i); + var ${direction}: ${dType} = originalIdx + ${dType}(i); if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) { if (${excludeOutside}) { coefs[i + 1] = 0.0; @@ -393,10 +395,11 @@ const bicubicInterpolation = ${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1)); } } - var inputIndicesCopy: ${input.type.indices} = inputIndices; - inputIndicesCopy[${idx}] = u32(${direction}); - data[i + 1] = ${idx === heightIdx ? `input[${input.indicesToOffset('inputIndicesCopy')}];` : ` - rowCubicInterpolation(inputIndicesCopy, outputIndices);`} + var input_indices_copy: ${input.type.indices} = input_indices; + ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)}; + data[i + 1] = ${ + idx === heightIdx ? 
input.getByIndices('input_indices_copy') : + 'rowCubicInterpolation(input_indices_copy, output_indices)'}; } return cubicInterpolation1D(data, coefs); }`; @@ -405,12 +408,12 @@ const bicubicInterpolation = return ` ${createCubicInterpolationFunction(heightIdx)}; ${createCubicInterpolationFunction(widthIdx)}; - fn getCubicInterpolationCoefs(s: f32) -> array { + fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> { var absS = abs(s); - var coeffs: array = array(0.0, 0.0, 0.0, 0.0); - var oneMinusAbsS: f32 = 1.0 - absS; - var twoMinusAbsS: f32 = 2.0 - absS; - var onePlusAbsS: f32 = 1.0 + absS; + var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); + var oneMinusAbsS: ${dType} = 1.0 - absS; + var twoMinusAbsS: ${dType} = 2.0 - absS; + var onePlusAbsS: ${dType} = 1.0 + absS; coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${ cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA}; coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1; @@ -420,14 +423,14 @@ const bicubicInterpolation = return coeffs; } - fn cubicInterpolation1D(x: array, coefs: array) -> f32 { - var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3]; + fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} { + var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3]; return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum; } - fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 { - var inputIndices: ${input.type.indices} = outputIndices; - return colCubicInterpolation(inputIndices, outputIndices); + fn bicubicInterpolation(output_indices: ${output.type.indices}) -> ${dType} { + var input_indices: ${input.type.indices} = output_indices; + return colCubicInterpolation(input_indices, output_indices); } `; }; @@ -446,27 +449,28 @@ const createResizeProgramInfo = outputShape = adjustOutputShape(inputShape, scales, attributes); } } - const output = outputVariable('output', inputTensor.dataType, outputShape); - const input = inputVariable('input', inputTensor.dataType, inputShape); + const output = outputVariable('output', inputTensor.dataType, outputShape.length); + const input = inputVariable('input', inputTensor.dataType, inputShape.length); const outputSize = ShapeUtil.size(outputShape); const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; + const dataType = input.type.value; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${noScale ? 
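A useful sanity check on `getCubicInterpolationCoefs` above: at a zero fractional offset the weights must collapse onto the centre sample. A host-side check (the last two coefficients are reconstructed from the symmetric form of the first two, since the hunk elides them; treat them as an assumption):

```ts
// Keys-style cubic convolution weights for fractional offset s and
// coefficient a (cubicCoeffA; -0.75 is the ONNX Resize default).
const cubicCoefs = (s: number, a: number): [number, number, number, number] => {
  const absS = Math.abs(s);
  const oneMinus = 1 - absS, twoMinus = 2 - absS, onePlus = 1 + absS;
  return [
    ((a * onePlus - 5 * a) * onePlus + 8 * a) * onePlus - 4 * a,
    ((a + 2) * absS - (a + 3)) * absS * absS + 1,
    ((a + 2) * oneMinus - (a + 3)) * oneMinus * oneMinus + 1,
    ((a * twoMinus - 5 * a) * twoMinus + 8 * a) * twoMinus - 4 * a,
  ];
};

console.log(cubicCoefs(0, -0.75));  // [0, 1, 0, 0] — only the centre sample contributes
```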
'' : ` - ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; + ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)}; ${(() => { switch (attributes.mode) { case 'nearest': return ` ${checkInputIndices(input, inputShape)}; - ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)}; + ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)}; ${ calculateInputIndicesFromOutputIndices( - input, output, inputShape, outputShape, scales, roi, useExtrapolation)}; + input, output, inputShape, outputShape, scales.length, roi.length, useExtrapolation)}; `; case 'linear': return ` - ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales, roi)}; + ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)}; ${ bilinearInterpolation( input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}; @@ -483,25 +487,29 @@ const createResizeProgramInfo = } })()}; `} - ${shaderHelper.declareVariables(input, output)} + ${ + shaderHelper.registerUniform('output_size', 'u32') + .registerUniform('scales', 'f32', scales.length) + .registerUniform('roi', 'f32', roi.length) + .declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} ${noScale ? 'output[global_idx] = input[global_idx];' : ` - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; + let output_indices = ${output.offsetToIndices('global_idx')}; + var input_indices: ${input.type.indices}; ${(() => { switch (attributes.mode) { case 'nearest': - return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices); - if (checkInputIndices(inputIndices)) { - output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; + return `input_indices = calculateInputIndicesFromOutputIndices(output_indices); + if (checkInputIndices(input_indices)) { + output[global_idx] = ${input.getByIndices('input_indices')}; } else { output[global_idx] = ${attributes.extrapolationValue}; }`; case 'linear': - return 'output[global_idx] = bilinearInterpolation(outputIndices);'; + return 'output[global_idx] = bilinearInterpolation(output_indices);'; case 'cubic': - return 'output[global_idx] = bicubicInterpolation(outputIndices);'; + return 'output[global_idx] = bicubicInterpolation(output_indices);'; default: throw Error(`Unsupported resize mode: ${attributes.mode}`); } @@ -513,12 +521,20 @@ const createResizeProgramInfo = name: 'Resize', shaderCache: { hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : ''}|${noScale}` + sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? 
roi : ''}|${noScale}`, + inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputTensor.dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ + {type: 'uint32', data: outputSize}, + {type: 'float32', data: scales}, + {type: 'float32', data: roi}, + ...createTensorShapeVariables(inputShape), + ...createTensorShapeVariables(outputShape), + ] }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index d607351f69b74..5212c6475dce0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -77,21 +77,25 @@ const fixStartEndValues = }; const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): - string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { - var inputIndices: ${input.type.indices}; + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[]): string => + `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { + var input_indices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { - var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex = outputIndex * steps[i] + starts[i] + carry; - carry = inputIndex / inputShape[i]; - inputIndex = inputIndex % inputShape[i]; - if (signs[i] < 0) { - inputIndex = inputShape[i] - inputIndex - 1u + starts[i]; + let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; + let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; + let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; + let starts_i = ${getElementAt('uniforms.starts', 'i', inputShape.length)}; + var output_index = ${output.indicesGet('output_indices', 'i')}; + var input_index = output_index * steps_i + starts_i + carry; + carry = input_index / input_shape_i; + input_index = input_index % input_shape_i; + if (signs_i < 0) { + input_index = input_shape_i - input_index - 1u + starts_i; } - ${inputShape.length === 1 ? 
'inputIndices' : 'inputIndices[i]'} = inputIndex; + ${input.indicesSet('input_indices', 'i', 'input_index')}; } - return inputIndices; + return input_indices; }`; const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => { @@ -110,6 +114,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); + if (axes.length !== starts.length || axes.length !== ends.length) { + throw new Error('start, ends and axes should have the same number of elements'); + } + if (axes.length !== inputShape.length) { for (let i = 0; i < inputShape.length; ++i) { if (!axes.includes(i)) { @@ -131,40 +139,44 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice array[i] = -step; } }); - + // Output rank is expected to be less than or equal to the input rank. const outputShape = inputShape.slice(0); axes.forEach((axis, _) => { outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); - const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; - const output = outputVariable('output', inputs[0].dataType, outputShape); - const input = inputVariable('input', inputs[0].dataType, inputShape); + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length); const outputSize = ShapeUtil.size(outputShape); + const uniforms: UniformsArrayType = [ + {name: 'outputSize', type: 'u32'}, {name: 'starts', type: 'u32', length: starts.length}, + {name: 'signs', type: 'i32', length: signs.length}, {name: 'steps', type: 'u32', length: steps.length} + ]; + + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, + {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(outputShape) + ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} - const signs = array(${signs.map(i => `${i}i`).join(',')}); - const starts = array(${starts.map(i => `${i}u`).join(',')}); - const ends = array(${ends.map(i => `${i}u`).join(',')}); - const steps = array(${steps.map(i => `${i}u`).join(',')}); - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - - ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${calculateInputIndicesImpl(input, output, inputShape)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let outputIndices = ${output.offsetToIndices('global_idx')}; - let inputIndices = calculateInputIndices(outputIndices); - ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + let output_indices = ${output.offsetToIndices('global_idx')}; + let input_indices = calculateInputIndices(output_indices); + ${output.setByOffset('global_idx', input.getByIndices('input_indices'))} }`; return { name: 'Slice', - shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? 
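A worked example of the output-shape rule used above, `outputShape[axis] = ceil((ends[axis] - starts[axis]) / steps[axis])`:

```ts
// Slice(starts=[1], ends=[8], steps=[2]) over a length-10 axis keeps
// indices 1, 3, 5, 7 — four elements.
const starts = [1], ends = [8], steps = [2];
const inputShape = [10];
const outputShape =
    inputShape.map((_, axis) => Math.ceil((ends[axis] - starts[axis]) / steps[axis]));
console.log(outputShape);  // [4]
```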
''}`}, + shaderCache: {hint: `${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: [outputTensorInfo], dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 378a7e738dac9..324dc3af1a710 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -73,8 +73,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut } ${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)} ${shaderHelper.mainStart()} - let gindex = i32(global_id.x); - let lindex = i32(local_id.x); + let gindex = i32(global_idx); + let lindex = i32(local_idx); const wg = ${WG}; let row = gindex / wg; let cols = uniforms.packedCols; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index fd60d81b87ae1..b8582614fa214 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface SplitAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -34,7 +34,7 @@ const createSplitAttributesFromInputs = const calculateOutputIndexImpl = (numberOfTensors: number): string => ` fn calculateOutputIndex(index: u32) -> u32 { for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { - if (index < sizeInConcatAxis[i]) { + if (index < ${getElementAt('uniforms.size_in_split_axis', 'i', numberOfTensors)}) { return i; } } @@ -48,15 +48,15 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => { if (numberOfTensors === 1) { codeLines.push(returnSnippet); } else if (i === 0) { - codeLines.push(`if (outputNumber == ${i}u) { ${returnSnippet} }`); + codeLines.push(`if (output_number == ${i}u) { ${returnSnippet} }`); } else if (i === numberOfTensors - 1) { codeLines.push(`else { ${returnSnippet} }`); } else { - codeLines.push(`else if (outputNumber == ${i}) { ${returnSnippet} }`); + codeLines.push(`else if (output_number == ${i}) { ${returnSnippet} }`); } } return ` - fn writeBufferData(outputNumber: u32, indices: ${outputs[0].type.indices}, global_idx: u32) { + fn writeBufferData(output_number: u32, indices: ${outputs[0].type.indices}, global_idx: u32) { ${codeLines.join('\n')} }`; }; @@ -65,48 +65,54 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); const dataType = inputs[0].dataType; - const rank = inputShape.length; - const axis = attributes.axis; - const adjustedAxis = (axis < 0) ? 
inputShape.length + axis : axis; + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); const input = inputVariable('input', dataType, inputShape); - const sizeInConcatAxis = new Array(attributes.numOutputs); + const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; - sizeInConcatAxis[i] = previousSum; + sizeInSplitAxis[i] = previousSum; const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShapes[i]); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; + programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); + programUniforms.push(...createTensorShapeVariables(inputShape)); + outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, ...outputs)} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); - ${calculateOutputIndexImpl(sizeInConcatAxis.length)} + ${ + shaderHelper.registerUniform('input_size', 'u32') + .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length) + .declareVariables(input, ...outputs)} + ${calculateOutputIndexImpl(sizeInSplitAxis.length)} ${writeBufferDataImpl(outputs)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(inputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.input_size')} var indices = ${input.offsetToIndices('global_idx')}; - let outputNumber = calculateOutputIndex(${indicesAxis}); - if (outputNumber != 0) { - ${indicesAxis} -= sizeInConcatAxis[outputNumber - 1u]; + var index = ${input.indicesGet('indices', axis)}; + let output_number = calculateOutputIndex(index); + if (output_number != 0) { + index -= ${getElementAt('uniforms.size_in_split_axis', 'output_number - 1u', sizeInSplitAxis.length)}; + ${input.indicesSet('indices', axis, 'index')}; } - writeBufferData(outputNumber, indices, global_idx); + writeBufferData(output_number, indices, global_idx); }`; return { name: 'Split', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: outputsTensorInfo, dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index e294541a775ca..90a36a7bec2a9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const getRepeats = (repeatsTensorView: TensorView): readonly number[] => 
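The running sums built above (`sizeInSplitAxis`) double as bucket boundaries: `calculateOutputIndex` resolves a coordinate along the split axis by finding the first limit it falls under. A compact host-side illustration:

```ts
// Split sizes [2, 3, 5] along a length-10 axis give cumulative limits [2, 5, 10].
const splitSizes = [2, 3, 5];
const sizeInSplitAxis: number[] = [];
let previousSum = 0;
for (const size of splitSizes) {
  previousSum += size;
  sizeInSplitAxis.push(previousSum);  // [2, 5, 10]
}
const calculateOutputIndex = (index: number): number =>
    sizeInSplitAxis.findIndex(limit => index < limit);

const outputNumber = calculateOutputIndex(6);  // 2 → third output tensor
const localIndex = 6 - (outputNumber ? sizeInSplitAxis[outputNumber - 1] : 0);  // 1
```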
Array.from(repeatsTensorView.getBigInt64Array(), Number); @@ -54,30 +54,35 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf const outputSize = ShapeUtil.size(outputShape); const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, inputShape); - const output = outputVariable('output', dataType, outputShape); + const input = inputVariable('input', dataType, inputShape.length); + const output = outputVariable('output', dataType, outputShape.length); const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let output_indices = ${output.offsetToIndices('global_idx')}; + var input_indices: ${input.type.indices}; for (var i = 0; i < ${inputShape.length}; i++) { - let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')}; + let input_dim_i = ${input.indicesGet('uniforms.input_shape', 'i')}; + let input_dim_value = ${output.indicesGet('output_indices', 'i')} % input_dim_i; - ${input.indicesSet('inputIndices', 'i', 'inputDimValue')} + ${input.indicesSet('input_indices', 'i', 'input_dim_value')} } - ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} + ${output.setByOffset('global_idx', input.getByIndices('input_indices'))} }`; return { name: 'Tile', - shaderCache: {hint: `${repeats}`}, + shaderCache: {hint: `${repeats}`, inputDependencies: ['rank']}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ + {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(outputShape) + ], }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 4238449f9246f..a25e7fe4229b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType} from './common'; type BuiltinFunctionName = string; type ElementwiseCustomExpression = (expression: string) => string; @@ -124,8 +124,15 @@ export interface ClipAttributes extends AttributeWithCacheKey { readonly max: number; } -export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); +const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { + const min = (inputs.length >= 2 && inputs[1].data !== 0) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3 && inputs[2].data !== 0) ? 
inputs[2].getFloat32Array()[0] : MAX_CLIP; + return createAttributeWithCacheKey({min, max}); +}; + +export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { + const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfo( context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` @@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo attributes.cacheKey), {inputs: [0]}); }; -const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; - return createAttributeWithCacheKey({min, max}); -}; - -export const clip = (context: ComputeContext): void => { - const attributes = generateClipAttributesFromInputs(context.inputs); - clipV10(context, attributes); -}; export const ceil = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil')); @@ -166,15 +163,16 @@ export const parseAlphaAttributes = (attributes: Record): Alpha createAttributeWithCacheKey(attributes as {alpha: number}); export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( context.inputs[0], 'Elu', a => `elu_vf32(${a})`, ` - const elu_alpha_: f32 = f32(${attributes.alpha}); + const elu_alpha_ = ${dataType}(${attributes.alpha}); - fn elu_f32(a: f32) -> f32 { + fn elu_f32(a: ${dataType}) -> ${dataType} { return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0); } - fn elu_vf32(v: vec4) -> vec4 { + fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> { return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w)); }`, attributes.cacheKey)); @@ -195,7 +193,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { }`; export const erf = (context: ComputeContext): void => { - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); }; @@ -209,16 +207,17 @@ export const floor = (context: ComputeContext): void => { }; export const gelu = (context: ComputeContext): void => { - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(`vec4<${dataType}>`, dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4(0.0))`, - `const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey)); + context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`, + `const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, 
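Context for the Clip rework above: opset 10 carries min/max as node attributes, while opset 11+ passes them as optional tensor inputs, so the merged kernel resolves attributes from inputs whenever more than one input is present. A minimal restatement of the defaulting rule (the MIN_CLIP/MAX_CLIP values below are my assumption that they are the finite float32 limits):

```ts
const MIN_CLIP = -3.4028234663852886e38;  // assumed: lowest finite float32
const MAX_CLIP = 3.4028234663852886e38;   // assumed: largest finite float32

// Absent (or zero-sized) min/max inputs fall back to the full float32 range.
const resolveClipBounds = (min?: Float32Array, max?: Float32Array): {min: number; max: number} => ({
  min: min?.length ? min[0] : MIN_CLIP,
  max: max?.length ? max[0] : MAX_CLIP,
});

console.log(resolveClipBounds(new Float32Array([0])));  // {min: 0, max: 3.4028234663852886e38}
```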
attributes.cacheKey)); }; export const not = (context: ComputeContext): void => { @@ -234,8 +233,9 @@ export const reciprocal = (context: ComputeContext): void => { }; export const relu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Relu', a => `select(vec4(0.0), ${a}, ${a} > vec4(0.0))`)); + context.inputs[0], 'Relu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`)); }; export const sigmoid = (context: ComputeContext): void => { @@ -263,9 +263,10 @@ export const tanh = (context: ComputeContext): void => { }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'ThresholdedRelu', a => `select(vec4(0.0), ${a}, ${a} > thresholded_relu_alpha_)`, - `const thresholded_relu_alpha_: vec4 = vec4(${attributes.alpha});`, attributes.cacheKey)); + context.inputs[0], 'ThresholdedRelu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`, + `const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, attributes.cacheKey)); return 0; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 6f66dd86b4088..687ee054096cc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const createWhereOpProgramShader = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, typeOutput: number) => { - const outputSize = ShapeUtil.size(dimsOutput); - const vecSize = Math.ceil(outputSize / 4); - - const output = outputVariable('outputData', typeOutput, dimsOutput, 4); - const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); - const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); - const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); + const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4); + const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4); + const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4); + const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4); let assignment: string; const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; @@ -27,20 +24,20 @@ const createWhereOpProgramShader = expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); } else { const singleAssignment = (resStr: string, x: number, typeCast = '') => { - const expressionA = `aData[indexA${x}][componentA${x}]`; - const expressionB = `bData[indexB${x}][componentB${x}]`; + const expressionA = `a_data[index_a${x}][component_a${x}]`; + const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] 
& ${0xff000000 >>> ((3 - x) * 8)}u)`; return ` - let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; - let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; - let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; - let indexA${x} = offsetA${x} / 4u; - let indexB${x} = offsetB${x} / 4u; - let indexC${x} = offsetC${x} / 4u; - let componentA${x} = offsetA${x} % 4u; - let componentB${x} = offsetB${x} % 4u; + let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; + let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)}; + let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)}; + let index_a${x} = offset_a${x} / 4u; + let index_b${x} = offset_b${x} / 4u; + let index_c${x} = offset_c${x} / 4u; + let component_a${x} = offset_a${x} % 4u; + let component_b${x} = offset_b${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; @@ -51,21 +48,21 @@ const createWhereOpProgramShader = ${singleAssignment('data', 1, 'u32')} ${singleAssignment('data', 2, 'u32')} ${singleAssignment('data', 3, 'u32')} - outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; } else { assignment = ` - ${singleAssignment('outputData[global_idx]', 0)} - ${singleAssignment('outputData[global_idx]', 1)} - ${singleAssignment('outputData[global_idx]', 2)} - ${singleAssignment('outputData[global_idx]', 3)} + ${singleAssignment('output_data[global_idx]', 0)} + ${singleAssignment('output_data[global_idx]', 1)} + ${singleAssignment('output_data[global_idx]', 2)} + ${singleAssignment('output_data[global_idx]', 3)} `; } } return ` - ${shaderHelper.declareVariables(c, a, b, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} ${assignment} }`; }; @@ -79,6 +76,7 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); let outputShape = dimsA; let outputSize = ShapeUtil.size(dimsA); + const vecSize = Math.ceil(outputSize / 4); // TODO: deal with zero-sized tensors (eg. 
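The `dot()` above packs four boolean lanes into one u32, one byte per lane, which the condition input is unpacked from with the shifted 0xff masks. A TypeScript cross-check (illustrative; the weights and mask expression are the ones in the hunk):

```ts
// Pack four 0/1 lanes into one u32: lane x occupies byte x (little-endian).
const packFourBools = (data: [number, number, number, number]): number =>
    data[0] * 0x1 + data[1] * 0x100 + data[2] * 0x10000 + data[3] * 0x1000000;

// Matching unpack — the same mask expression as the shader's expressionC.
const unpackBool = (word: number, x: 0 | 1 | 2 | 3): boolean =>
    (word & (0xff000000 >>> ((3 - x) * 8))) !== 0;

const word = packFourBools([1, 0, 1, 1]);
console.log(word.toString(16));  // "1010001"
console.log([0, 1, 2, 3].map(x => unpackBool(word, x as 0 | 1 | 2 | 3)));
// [true, false, true, true]
```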
dims=[1,0]) if (isBroadcast) { @@ -92,11 +90,16 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => return { name: 'Where', + shaderCache: {inputDependencies: ['rank', 'rank', 'rank']}, getShaderSource: (shaderHelper) => createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, + programUniforms: [ + {type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), + ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) + ], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 0b0a545f46481..ae5bf68483b46 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -75,12 +75,11 @@ export class ProgramManager { const kernelId = this.backend.currentKernelId!; const kernelInfo = this.backend.kernels.get(kernelId)!; - const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`; void syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { const mappedData = new BigUint64Array(syncData.buffer.getMappedRange()); - const startTimeU64 = mappedData[0]; - const endTimeU64 = mappedData[1]; + const [startTimeU64, endTimeU64] = mappedData; + const [kernelType, kernelName] = kernelInfo; syncData.buffer.unmap(); @@ -96,17 +95,33 @@ export class ProgramManager { } this.backend.gpuDataManager.release(syncData.id); - let inputShapes = ''; - inputTensorViews.forEach((value, i) => { - inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; - }); - let outputShapes = ''; - outputTensorViews.forEach((value, i) => { - outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; - }); - // eslint-disable-next-line no-console - console.log(`[profiling] kernel "${kernelId}|${kernelName}" ${inputShapes}${outputShapes}execution time: ${ - endTime - startTime} ns`); + if (this.backend.env.webgpu.profiling?.ondata) { + this.backend.env.webgpu.profiling.ondata({ + version: 1, + inputsMetadata: inputTensorViews.map( + value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), + outputsMetadata: outputTensorViews.map( + value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), + kernelId, + kernelType, + kernelName, + startTime, + endTime, + }); + } else { + // if no callback is provided, print the profiling message to console + let inputShapes = ''; + inputTensorViews.forEach((value, i) => { + inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; + }); + let outputShapes = ''; + outputTensorViews.forEach((value, i) => { + outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; + }); + // eslint-disable-next-line no-console + console.log(`[profiling] kernel "${kernelId}|${kernelName}|${buildArtifact.programInfo.name}" ${inputShapes}${ + outputShapes}execution time: ${endTime - startTime} ns`); + } }); } diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 09d91591128d1..71815f21e650a 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ 
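The program-manager change above routes per-kernel GPU timings to a user callback when one is registered, falling back to console logging otherwise. Consuming the hook could look like the sketch below (the payload fields mirror the object constructed in program-manager.ts; treating `mode: 'default'` as the switch that enables timestamp collection is my assumption):

```ts
import {env} from 'onnxruntime-common';

// Route per-kernel GPU timings into custom tooling instead of the console.
env.webgpu.profiling = {
  mode: 'default',  // assumed: enables timestamp queries
  ondata: (data) => {
    const ns = data.endTime - data.startTime;
    console.log(`[${data.kernelType}] ${data.kernelName} (#${data.kernelId}): ${ns} ns`);
  },
};
```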
b/js/web/lib/wasm/session-handler-training.ts @@ -1,28 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, lazyResetGrad, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } private sessionId: number; private checkpointId: number; inputNames: string[]; outputNames: string[]; - inputEncodedNames: number[]; - outputEncodedNames: number[]; + evalInputNames: string[] = []; + evalOutputNames: string[] = []; async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { let buffer: Uint8Array; @@ -57,8 +51,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } this.checkpointId = createCheckpointHandle(checkpointData); - [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + this.sessionId = createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); + if (evalModelUriOrBuffer !== '') { + [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); + } } /** @@ -107,6 +105,10 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return resultMap; } + async lazyResetGrad(): Promise { + await lazyResetGrad(this.sessionId); + } + async runTrainStep( feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise { @@ -124,8 +126,40 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); } + async runOptimizerStep(options: InferenceSession.RunOptions): Promise { + await runOptimizerStep(this.sessionId, options); + } + + async runEvalStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( + feeds, this.evalInputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`)); + + const [outputArray, outputIndices, outputs] = + this.convertMapIntoValuesArrayAndIndicesArray( + fetches, this.evalOutputNames, + (t, i): TensorMetadata|null 
=> + t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null); + + const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); + } + + async getParametersSize(trainableOnly: boolean): Promise<number> { + return getParametersSize(this.sessionId, trainableOnly); + } + + async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> { + await loadParametersBuffer(this.sessionId, array, trainableOnly); + } + async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> { + const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); + return decodeTensorMetadata(tensorResult); + } + async dispose(): Promise<void> { - return releaseTrainingSessionAndCheckpoint( - this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); } } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index a35d285346db4..0cc28188a6093 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,10 +3,10 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; @@ -16,6 +16,22 @@ const NO_TRAIN_FUNCS_MSG = 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; +/** + * Runs the checkLastError function, which will throw an error if the provided error code matches the specified + * pattern for an error code. + * @param errCode number to be evaluated for whether it is an error + * @param message message to pass into checkLastError + * @param checkNeqZero when true, treats not equal to zero as an error. + * When false, treats equal to zero as an error. 
+ */ +const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { + if (checkNeqZero && errCode !== 0) { + checkLastError(message); + } else if (!checkNeqZero && errCode === 0) { + checkLastError(message); + } +}; + export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -29,9 +45,7 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n throw new Error(NO_TRAIN_FUNCS_MSG); } - if (checkpointHandle === 0) { - checkLastError('Error occurred when trying to create a CheckpointState.'); - } + ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); return checkpointHandle; } catch (e) { if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { @@ -52,9 +66,7 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea if (wasm._OrtTrainingGetModelInputOutputCount) { const errorCode = wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } + ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); @@ -65,52 +77,44 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea }; const getModelInputOutputNamesLoop = - (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => { const names = []; const wasm = getInstance(); - const namesUTF8Encoded = []; - for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetModelInputOutputName) { const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - if (name === 0) { - checkLastError('Can\'t get input or output name'); - } + ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); + wasm._free(name); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } } - return [names, namesUTF8Encoded]; + return names; }; -const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); +export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { + let inputNames: string[] = []; + let outputNames: string[] = []; + + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); - const [outputNames, outputNamesUTF8Encoded] = - getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); + outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; + return [inputNames, outputNames]; }; export const createTrainingSessionHandle = (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: 
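Both polarities of the helper above appear throughout the hunks that follow; for reference, a sketch of the two call styles (placeholder declarations added so the snippet type-checks on its own):

```ts
declare const errorCode: number;         // status code returned by a wasm call
declare const checkpointHandle: number;  // handle returned by a wasm call

// Status-code style: a non-zero return means failure.
ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.');

// Handle style: a zero handle means failure, hence checkNeqZero = false.
ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
```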
SerializableModeldata, - optimizerModelData: SerializableModeldata, - options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => { const wasm = getInstance(); let trainingSessionHandle = 0; let sessionOptionsHandle = 0; let allocs: number[] = []; - let inputNamesUTF8Encoded: number[] = []; - let outputNamesUTF8Encoded: number[] = []; - - let inputNames: string[] = []; - let outputNames: string[] = []; try { [sessionOptionsHandle, allocs] = setSessionOptions(options); @@ -122,14 +126,8 @@ export const createTrainingSessionHandle = throw new Error(NO_TRAIN_FUNCS_MSG); } - if (trainingSessionHandle === 0) { - checkLastError('Error occurred when trying to create a TrainingSession.'); - } - - [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = - getTrainingModelInputOutputNames(trainingSessionHandle); - return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; - + ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); + return trainingSessionHandle; } catch (e) { if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { wasm._OrtTrainingReleaseSession(trainingSessionHandle); @@ -144,8 +142,6 @@ export const createTrainingSessionHandle = wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } allocs.forEach(alloc => wasm._free(alloc)); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); } }; @@ -213,9 +209,8 @@ const moveOutputToTensorMetadataArr = try { const errorCode = wasm._OrtGetTensorData( tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } + ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); + let tensorDataIndex = tensorDataOffset / 4; const dataType = wasm.HEAPU32[tensorDataIndex++]; dataOffset = wasm.HEAPU32[tensorDataIndex++]; @@ -258,6 +253,17 @@ const moveOutputToTensorMetadataArr = return output; }; +export const lazyResetGrad = async(trainingSessionId: number): Promise<void> => { + const wasm = getInstance(); + + if (wasm._OrtTrainingLazyResetGrad) { + const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId); + ifErrCodeCheckLastError(errorCode, 'Can\'t call lazyResetGrad.'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } +}; + export const runTrainStep = async( trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => { @@ -290,10 +296,84 @@ export const runTrainStep = async( if (wasm._OrtTrainingRunTrainStep) { const errorCode = wasm._OrtTrainingRunTrainStep( trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } - if (errorCode !== 0) { - checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); } finally { wasm.stackRestore(beforeRunStack); + inputTensorHandles.forEach(v =>
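
For context, the exports introduced above and below (lazyResetGrad, runTrainStep, runOptimizerStep, releaseTrainingSessionAndCheckpoint) complete the step cycle of the WebAssembly training API. A minimal usage sketch, with `checkpointId` and `sessionId` assumed to be live handles, and `features`/`labels` and the index values purely hypothetical placeholders for tensors matching the model's graph inputs:

const features: TensorMetadata = ['float32', [2, 4], new Float32Array(8), 'cpu'];  // hypothetical batch
const labels: TensorMetadata = ['int64', [2], new BigInt64Array(2), 'cpu'];        // hypothetical targets
await lazyResetGrad(sessionId);                        // zero gradients lazily before the step
const [loss] = await runTrainStep(
    sessionId, [0, 1], [features, labels],             // hypothetical input indices and tensors
    [2], [null], {});                                  // null lets the backend allocate the loss output
await runOptimizerStep(sessionId, {});                 // apply the optimizer update
// ...repeat per batch, then release the native resources:
releaseTrainingSessionAndCheckpoint(checkpointId, sessionId);
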
wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + +export const runOptimizerStep = + async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise<void> => { + const wasm = getInstance(); + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + try { + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + if (wasm._OrtTrainingOptimizerStep) { + const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); + ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + +export const runEvalStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + + try { + // prepare parameters by moving them to heap + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + + if (wasm._OrtTrainingEvalStep) { + const errorCode = wasm._OrtTrainingEvalStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -313,17 +393,135 @@ export const runTrainStep = async( } }; -export const releaseTrainingSessionAndCheckpoint = - (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): - void => { - const wasm = getInstance(); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); +export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { + const wasm = getInstance(); + const stack = wasm.stackSave(); - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } - }; + try { + const sizeOffset = wasm.stackAlloc(4); + if (wasm._OrtTrainingGetParametersSize) { + const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + ifErrCodeCheckLastError(errorCode, 'Can\'t
get parameters size'); + + return wasm.HEAP32[sizeOffset / 4]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +export const getContiguousParameters = + async(trainingSessionId: number, trainableOnly: boolean): Promise<TensorMetadata> => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + const parametersSize = getParametersSize(trainingSessionId, trainableOnly); + let tensor = 0; + + // allocates a buffer of the correct size on the WASM heap + const paramsByteLength = 4 * parametersSize; + const paramsOffset = wasm._malloc(paramsByteLength); + + // handles the dimensions-related createTensor parameters + const dims = [parametersSize]; + + const dimsOffset = wasm.stackAlloc(4); + const dimsIndex = dimsOffset / 4; + wasm.HEAP32[dimsIndex] = parametersSize; + + try { + // wraps allocated array in a tensor + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError( + tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false); + + if (wasm._OrtTrainingCopyParametersToBuffer) { + const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.'); + + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); + const data = new typedArrayConstructor(parametersSize); + const output: TensorMetadata[] = []; + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + output.push([tensorTypeAsString, dims, data, locationAsString]); + if (output.length !== 1) { + throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of + one, got ${output.length}`); + } else { + return output[0]; + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm._free(paramsOffset); + wasm._free(dimsOffset); + wasm.stackRestore(stack); + } +}; + +export const loadParametersBuffer = + async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise<void> => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + // allocates & copies JavaScript buffer to WASM heap + const bufferByteLength = buffer.length; + const bufferCount = bufferByteLength / 4; + const bufferOffset = wasm._malloc(bufferByteLength); + wasm.HEAPU8.set(buffer, bufferOffset); + + // allocates and handles moving dimensions information to WASM memory + const dimsOffset = wasm.stackAlloc(4); + wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsLength = 1; + let tensor = 0; + + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output.
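
The two parameter-copy helpers are symmetric: getContiguousParameters flattens the (optionally trainable-only) parameters into a single float32 tensor, and loadParametersBuffer writes such a flat buffer back. A hedged round-trip sketch, again assuming `sessionId` is a live training session handle:

const count = getParametersSize(sessionId, /* trainableOnly */ true);   // float32 element count
const [, , data] = await getContiguousParameters(sessionId, true);      // ['float32', [count], Float32Array, 'cpu']
const bytes = new Uint8Array((data as Float32Array).buffer, 0, count * 4);
await loadParametersBuffer(sessionId, bytes, true);                     // e.g. after averaging weights across peers
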
session=${trainingSessionId}`, false); + + if (wasm._OrtTrainingCopyParametersFromBuffer) { + const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm.stackRestore(stack); + wasm._free(bufferOffset); + wasm._free(dimsOffset); + } +}; + +export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { + const wasm = getInstance(); + + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } +}; diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 7408f17004f5e..eab8175a941bd 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -16,6 +16,8 @@ const COMMENTS: Record<string, string> = { 'Reshape': 'no GPU kernel', 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', + 'Attention': 'need implementing mask and past/present', + 'MultiHeadAttention': 'need implementing mask and past/present', }; /* eslint-disable max-len */ diff --git a/js/web/test/data/ops/attention.jsonc b/js/web/test/data/ops/attention.jsonc new file mode 100644 index 0000000000000..bd4483027cc25 --- /dev/null +++ b/js/web/test/data/ops/attention.jsonc @@ -0,0 +1,557 @@ +[ + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [213, 213], + "dims": [1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic Batch 2 with 2 heads", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [2, 2, 8], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 + ], + "dims": [8, 6], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [320, 321, 320, 321, 320, 321, 320, 321], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863], + "dims": [1, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], +
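
As a sanity check, the [213, 213] expectation in the opening "Attention Basic" case can be reproduced by hand: with num_heads = 1, the [4, 3] weight packs the Q, K and V projections column-wise, and the second token's attention score dominates the softmax, so both positions return the second token's V value. A small sketch of that arithmetic (plain TypeScript, written for this note):

const x = [[1, 2, 3, 4], [5, 6, 7, 8]];                      // input, dims [1, 2, 4]
const w = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]];   // columns are Wq, Wk, Wv
const b = [1, 2, 3];
const proj = x.map(row => [0, 1, 2].map(c => row.reduce((s, v, i) => s + v * w[i][c], 0) + b[c]));
const q = proj.map(p => p[0]), k = proj.map(p => p[1]), v = proj.map(p => p[2]);  // q=[71,159], k=[82,186], v=[93,213]
const out = q.map(qi => {
  const scores = k.map(ki => qi * ki);                       // head_size = 1, so the 1/sqrt(d) scale is 1
  const m = Math.max(...scores);
  const e = scores.map(s => Math.exp(s - m));
  const z = e.reduce((a, c) => a + c, 0);
  return e.reduce((acc, p, j) => acc + (p / z) * v[j], 0);
});
console.log(out);                                            // ≈ [213, 213]
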
"dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1.328187108039856, -1.297916054725647, -0.8599594831466675], + "dims": [1, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic one head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [2, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 2 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989, -0.989, 1.1103, -1.6898], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.8701779842376709, -2.6158859729766846, 0.8710794448852539, -2.5763747692108154, 0.9005484580993652, + -2.182751178741455, 2.1661579608917236, -2.1045265197753906, 1.6716957092285156, -1.797281265258789, + 1.7134947776794434, -1.765358328819275 + ], + "dims": [2, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987 + ], + "dims": [2, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, 
-2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [ + 1.1103, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, -1.6898, -0.989, -1.9029953479766846, 0.8710794448852539, + -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, 1.7134947776794434 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.6956915855407715, -2.8863370418548584, 1.3899128437042236, 1.6789076328277588, -1.4083852767944336, + -1.7009180784225464, -3.1053788661956787, 3.5959298610687256, 1.1027096509933472, -0.009643087163567543, + -1.694351315498352, -2.9284396171569824, 1.734721302986145, 2.0606398582458496, -0.2571452260017395, + 3.671973943710327, -5.285338401794434, -6.833454132080078, 1.7506506443023682, -2.262148380279541, + 2.5110034942626953, 1.440049171447754, -0.9423203468322754, 1.7506506443023682, -1.86212158203125, + -0.5036701560020447, -5.732386589050293, -1.5674757957458496, 1.7506510019302368, -2.264472246170044 + ], + "dims": [2, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846 + ], + "dims": [1, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [1, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [3, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + 3.7965505123138428, 
-2.3799397945404053, -3.9530906677246094, 0.5844926834106445, -2.9756431579589844, + 2.448162794113159, 4.34546422958374, 1.9380426406860352, 0.5870105624198914, -2.7368364334106445, + -0.4769568145275116, 4.255186557769775, -3.9529950618743896, 0.6987408995628357, -2.9756433963775635 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, 
-1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.01101303100586, -5.782258987426758, 6.016238689422607, 0.26747000217437744, -6.992541313171387, + -8.011263847351074, -5.782248020172119, 5.366001129150391, 0.26747000217437744, -6.99449348449707, + -8.011263847351074, -5.782265663146973, 6.016238689422607, 0.26747000217437744, -6.992537021636963, + -6.102723598480225, -7.28973388671875, -4.578637599945068, 7.2203369140625, -6.028444766998291, + -6.102705478668213, -7.2897748947143555, -3.7882626056671143, 5.393260478973389, -5.754333972930908, + -1.3616288900375366, -7.289827823638916, -6.341128349304199, 6.329389572143555, -5.751791954040527, + -2.3945987224578857, -14.532954216003418, 3.969801902770996, 12.744998931884766, -11.1966552734375, + -2.4002532958984375, -14.538958549499512, -6.684961318969727, 12.476543426513672, -9.24352741241455, + -4.787771701812744, -8.640848159790039, 3.969801902770996, -0.6471102833747864, -11.1966552734375 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 1 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 
0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + 1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711, + -1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551, + -1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + 1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/batch-norm.jsonc b/js/web/test/data/ops/batch-norm.jsonc new file mode 100644 index 0000000000000..4ea16f290dc8f --- /dev/null +++ b/js/web/test/data/ops/batch-norm.jsonc @@ -0,0 +1,446 @@ +[ + { + "name": "BatchNormalization with no attributes", + "operator": "BatchNormalization", + "attributes": [], + "cases": [ + { + "name": "T[64]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425 + ], + "dims": [64], + "type": "float32" + }, + { + "data": [0.241661], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189 + ], + "dims": [64], + "type": "float32" + } + ] + }, + { + "name": "T[2,3,4,4,4]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 
1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425, 0.168795, 0.740422, -0.377683, 0.432598, -2.07414, -2.85251, 0.273531, + 0.0532606, 1.31052, -0.769382, 0.9976, 0.850536, -1.53812, -0.00496016, 0.931242, 0.0517056, -0.497829, + 0.275869, 0.860001, 1.23747, 0.179686, 1.5914, 0.740327, 0.798208, 2.12478, 1.74205, -0.322054, + -0.0112451, 0.204525, -0.431252, -1.3114, 0.186204, 0.780569, -1.42994, 1.63344, -0.00839034, -0.187035, + 1.8406, 1.32053, -0.636963, 0.408944, -1.50846, -1.2076, -0.129118, -0.0441307, 1.47558, 1.07251, 1.05295, + -0.420297, -1.13402, -0.524053, 3.20754, -0.588935, -0.527549, 0.591928, -1.10529, 0.520412, 0.19404, + -1.21229, -0.399594, -0.280935, -0.363324, -0.00804771, 1.43102, -0.523222, 1.17608, -0.53195, 0.914993, + 2.69308, -0.517211, 0.472273, -0.464725, -0.929768, -0.631145, 0.919709, -0.27391, 1.76689, 0.894897, + 0.235798, 1.2544, 0.858985, -0.139707, 0.354544, 0.200878, 0.353255, 0.0722632, -1.56074, 1.03685, + 1.73434, 0.193269, -0.864609, 0.842739, -0.372717, 0.584484, 0.16315, 1.60674, -0.0611289, -1.24544, + 1.33361, -0.961942, -0.15732, -0.348637, 0.361842, 0.7386, 0.517256, 1.20406, -2.07277, -1.01983, -1.9163, + 0.239934, 0.177979, 0.464564, 0.988822, 0.284607, -1.56099, -0.429143, 0.111043, -0.0853688, -0.319176, + -0.279777, 0.520971, -1.078, -0.670242, 0.065652, 0.468538, -0.825062, 0.370068, 1.68751, -1.16928, + -0.411782, 1.61624, -0.973004, 2.64703, -0.220014, -1.43954, -0.018692, 1.34982, -0.95197, -1.72586, + 1.32725, 0.280984, 0.00847463, 0.512869, 0.0378154, 0.13898, 0.35758, -0.084558, 1.04045, -1.79933, + 1.3002, 0.390457, 1.22267, 0.959344, -0.964296, -0.0935597, 0.288953, -0.158046, 0.532672, -0.500988, + 0.25187, -2.14384, -0.633315, 1.24612, -1.41525, 0.36494, -0.00714732, -0.608963, 0.508496, 0.995365, + 1.21159, -0.169055, -0.968783, 1.52779, -0.082381, 2.2049, 0.928655, 0.120245, 0.911429, -0.885258, + -1.2072, 0.770694, 2.36621, 1.08456, -1.60069, 0.0345025, 0.359559, -0.785411, 0.466532, -0.78543, + 0.024879, 1.59337, 1.13718, -1.27073, -0.263788, -1.7702, 0.203263, 1.34631, 1.11914, -2.04911, -0.804137, + 0.466763, 2.18386, 1.4689, 0.898297, -0.648948, 0.252202, 1.12501, -0.204563, 0.124608, 0.377214, + 0.894327, -0.249118, 0.709188, 0.999397, -1.4079, 0.193873, 0.657753, -0.709732, 1.09897, -0.145793, + 0.779199, 0.88378, -1.2676, 1.15709, 0.62295, -0.370894, -0.103268, -1.55949, -0.470747, 0.100394, + 0.422334, -0.0685312, -0.434488, -0.568974, -0.256987, 2.01276, -0.923322, -0.613144, 1.50676, 0.65756, + 1.20524, 1.10395, -0.975241, 2.44035, 1.08276, 0.330393, -0.508918, -1.25545, 0.189815, -0.156263, + -0.960866, 1.0859, -0.674478, 2.76743, 1.21399, 1.71666, -1.73198, -1.1062, 0.951285, -0.713336, 1.61586, + 1.96514, 0.002603, 0.0953297, 0.949256, -1.76552, 0.372816, -0.781229, 1.50532, 1.28462, 1.31116, + 0.731908, 1.54835, 0.371081, 0.409244, -0.106938, -1.79396, -1.61198, -0.80869, -1.10381, 1.1872, + -0.832439, 0.0755941, -1.09553, 0.960059, 1.44252, -0.196482, -1.07364, 0.165547, 0.630078, 1.56569, + -0.669592, 1.15974, 0.0953399, -0.202313, 0.812631, -0.318567, -0.16644, 0.887062, -0.0264821, -0.740725, + 0.0797577, -1.1037, 0.90236, 1.13427, 0.364186, -2.01043, -0.415748, 0.116046, 0.369949, 0.317886, + 0.530332, 1.48341, 0.74666, -1.64142, 0.22569, 1.18015, 1.31827, -1.33904, -0.101125 + ], + "dims": [2, 3, 4, 4, 4], + "type": "float32" + }, + { + "data": [0.241661, 0.960798, 0.474727], + "dims": [3], + "type": "float32" + }, + { + "data": 
[0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189, 0.162177, 0.711393, -0.362876, 0.415637, + -1.99282, -2.74067, 0.262807, 0.0511725, 1.25914, -0.739217, 0.958488, 0.817189, -1.47782, -0.00476569, + 0.894731, 0.0496784, -0.478311, 0.265053, 0.826283, 1.18895, 0.172641, 1.52901, 0.711301, 0.766913, + 2.04147, 1.67375, -0.309427, -0.0108042, 0.196507, -0.414344, -1.25999, 0.178903, 0.749965, -1.37387, + 1.5694, -0.00806138, -0.179702, 1.76844, 1.26875, -0.61199, 0.392911, -1.44932, -1.16025, -0.124055, + -0.0424004, 1.41773, 1.03046, 1.01167, -0.403818, -1.08956, -0.503507, 3.08178, -0.565845, -0.506866, + 0.56872, -1.06196, 0.500008, 0.186433, -1.16476, -0.383928, -0.269921, -0.349079, -0.00773219, 1.37492, + -0.248386, 0.558316, -0.25253, 0.43437, 1.27847, -0.245533, 0.2242, -0.220617, -0.441384, -0.29962, + 0.436609, -0.130032, 0.838785, 0.424829, 0.111939, 0.595496, 0.407781, -0.0663221, 0.168311, 0.0953618, + 0.167699, 0.0343051, -0.74092, 0.492219, 0.823334, 0.0917494, -0.410451, 0.400069, -0.176938, 0.277469, + 0.0774512, 0.762761, -0.0290194, -0.59124, 0.6331, -0.456657, -0.0746837, -0.165507, 0.171775, 0.350631, + 0.245554, 0.571595, -0.983996, -0.484139, -0.909715, 0.113902, 0.0844908, 0.22054, 0.469418, 0.13511, + -0.741041, -0.203725, 0.0527148, -0.0405267, -0.151521, -0.132817, 0.247318, -0.511752, -0.31818, + 0.0311666, 0.222426, -0.391677, 0.17568, 0.801104, -0.282569, -0.0995112, 0.39058, -0.235136, 0.639682, + -0.0531687, -0.347878, -0.0045171, 0.326198, -0.230053, -0.41707, 0.320744, 0.0679025, 0.00204798, + 0.12394, 0.00913847, 0.0335859, 0.0864127, -0.0204343, 0.251436, -0.434827, 0.314206, 0.0943579, 0.295471, + 0.231835, -0.233032, -0.0226096, 0.0698283, -0.0381934, 0.128725, -0.121069, 0.060867, -0.51808, + -0.153047, 0.301137, -0.342009, 0.0881915, -0.00172722, -0.147162, 0.122883, 0.24054, 0.292792, + -0.0408538, -0.234116, 0.369206, -0.0199082, 0.532835, 0.224419, 0.0290583, 0.220256, -0.213931, + -0.291733, 0.186246, 0.571817, 0.262095, -0.386822, 0.00833788, 0.086891, -0.189802, 0.112742, -0.189807, + 0.00601226, 0.385054, 0.274811, -1.22091, -0.253445, -1.7008, 0.195294, 1.29353, 1.07526, -1.96877, + -0.772609, 0.448463, 2.09824, 1.4113, 0.863078, -0.623505, 0.242314, 1.0809, -0.196543, 0.119722, + 0.362425, 0.859263, -0.239351, 0.681383, 0.960214, -1.3527, 0.186272, 0.631964, -0.681905, 1.05588, + -0.140077, 0.748649, 0.84913, -1.2179, 1.11172, 0.598526, -0.356353, -0.099219, -1.49835, -0.452291, + 0.0964582, 0.405776, -0.0658444, -0.417454, -0.546667, -0.246911, 1.93385, -0.887121, -0.589104, 1.44769, + 0.631779, 1.15798, 1.06067, -0.937005, 2.34467, 1.04031, 0.31744, -0.488965, -1.20623, 0.182373, + -0.150136, 
-0.923194, 1.04332, -0.648034, 2.65893, 1.1664, 1.64935, -0.822216, -0.525139, 0.451599, + -0.338638, 0.767087, 0.932899, 0.00123571, 0.0452554, 0.450635, -0.838136, 0.176985, -0.370868, 0.714614, + 0.60984, 0.622438, 0.347455, 0.73504, 0.176161, 0.194278, -0.0507662, -0.851639, -0.765246, -0.383905, + -0.524005, 0.563593, -0.395179, 0.0358864, -0.520076, 0.455763, 0.684801, -0.093275, -0.509682, 0.0785892, + 0.299113, 0.743272, -0.317872, 0.550556, 0.0452602, -0.0960432, 0.385776, -0.151232, -0.079013, 0.42111, + -0.0125717, -0.35164, 0.0378629, -0.523955, 0.428372, 0.538468, 0.172888, -0.954402, -0.197366, 0.0550898, + 0.175624, 0.150908, 0.251761, 0.704209, 0.354458, -0.779221, 0.107141, 0.560244, 0.625814, -0.635675, + -0.0480064 + ], + "dims": [2, 3, 4, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization with no attributes - NHWC", + "operator": "BatchNormalization", + "opset": { "domain": "com.ms.internal.nhwc", "version": 12 }, + "attributes": [], + "cases": [ + { + "name": "T[64]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425 + ], + "dims": [64], + "type": "float32" + }, + { + "data": [0.241661], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189 + ], + "dims": [64], + "type": "float32" + } + ] + }, + { + "name": "T[2,4,4,4,3]", + "inputs": [ + { + "data": [ + 2.02384, 0.168795, -0.523222, -0.935186, 0.740422, 1.17608, 0.488569, -0.377683, -0.53195, -0.513934, + 0.432598, 0.914993, -1.27082, -2.07414, 2.69308, -0.131913, -2.85251, -0.517211, -1.806, 0.273531, + 0.472273, -0.37904, 0.0532606, -0.464725, 0.667796, 1.31052, -0.929768, -1.14826, -0.769382, -0.631145, + 1.2522, 0.9976, 0.919709, 0.0300339, 0.850536, -0.27391, 2.4758, -1.53812, 1.76689, 1.55511, -0.00496016, + 0.894897, 0.385341, 0.931242, 0.235798, 1.46645, 0.0517056, 1.2544, -1.09355, -0.497829, 0.858985, + -2.56309, 0.275869, -0.139707, 0.976015, 0.860001, 0.354544, -1.47036, 1.23747, 0.200878, 0.89486, + 0.179686, 0.353255, 0.580989, 1.5914, 0.0722632, -1.12418, 0.740327, 
-1.56074, -0.339189, 0.798208, + 1.03685, 1.3314, 2.12478, 1.73434, 0.418893, 1.74205, 0.193269, -0.301401, -0.322054, -0.864609, -1.2983, + -0.0112451, 0.842739, -0.839063, 0.204525, -0.372717, 0.170261, -0.431252, 0.584484, 1.15486, -1.3114, + 0.16315, -0.255735, 0.186204, 1.60674, -0.589851, 0.780569, -0.0611289, -0.416289, -1.42994, -1.24544, + -0.952648, 1.63344, 1.33361, -0.360487, -0.00839034, -0.961942, 0.253287, -0.187035, -0.15732, 0.437195, + 1.8406, -0.348637, 0.32023, 1.32053, 0.361842, 0.209606, -0.636963, 0.7386, -0.279519, 0.408944, 0.517256, + -0.546527, -1.50846, 1.20406, 0.265286, -1.2076, -2.07277, -1.07383, -0.129118, -1.01983, -1.65879, + -0.0441307, -1.9163, 1.1222, 1.47558, 0.239934, 0.946612, 1.07251, 0.177979, 0.822549, 1.05295, 0.464564, + 0.64689, -0.420297, 0.988822, -0.292639, -1.13402, 0.284607, -0.73995, -0.524053, -1.56099, -0.694949, + 3.20754, -0.429143, 1.33899, -0.588935, 0.111043, -0.0652476, -0.527549, -0.0853688, 1.61791, 0.591928, + -0.319176, 1.49692, -1.10529, -0.279777, -0.761145, 0.520412, 0.520971, -0.201874, 0.19404, -1.078, + -1.15431, -1.21229, -0.670242, -1.83111, -0.399594, 0.065652, -0.705267, -0.280935, 0.468538, -0.143026, + -0.363324, -0.825062, -0.129819, -0.00804771, 0.370068, -0.799425, 1.43102, 1.68751, -1.16928, -1.27073, + -1.73198, -0.411782, -0.263788, -1.1062, 1.61624, -1.7702, 0.951285, -0.973004, 0.203263, -0.713336, + 2.64703, 1.34631, 1.61586, -0.220014, 1.11914, 1.96514, -1.43954, -2.04911, 0.002603, -0.018692, + -0.804137, 0.0953297, 1.34982, 0.466763, 0.949256, -0.95197, 2.18386, -1.76552, -1.72586, 1.4689, + 0.372816, 1.32725, 0.898297, -0.781229, 0.280984, -0.648948, 1.50532, 0.00847463, 0.252202, 1.28462, + 0.512869, 1.12501, 1.31116, 0.0378154, -0.204563, 0.731908, 0.13898, 0.124608, 1.54835, 0.35758, 0.377214, + 0.371081, -0.084558, 0.894327, 0.409244, 1.04045, -0.249118, -0.106938, -1.79933, 0.709188, -1.79396, + 1.3002, 0.999397, -1.61198, 0.390457, -1.4079, -0.80869, 1.22267, 0.193873, -1.10381, 0.959344, 0.657753, + 1.1872, -0.964296, -0.709732, -0.832439, -0.0935597, 1.09897, 0.0755941, 0.288953, -0.145793, -1.09553, + -0.158046, 0.779199, 0.960059, 0.532672, 0.88378, 1.44252, -0.500988, -1.2676, -0.196482, 0.25187, + 1.15709, -1.07364, -2.14384, 0.62295, 0.165547, -0.633315, -0.370894, 0.630078, 1.24612, -0.103268, + 1.56569, -1.41525, -1.55949, -0.669592, 0.36494, -0.470747, 1.15974, -0.00714732, 0.100394, 0.0953399, + -0.608963, 0.422334, -0.202313, 0.508496, -0.0685312, 0.812631, 0.995365, -0.434488, -0.318567, 1.21159, + -0.568974, -0.16644, -0.169055, -0.256987, 0.887062, -0.968783, 2.01276, -0.0264821, 1.52779, -0.923322, + -0.740725, -0.082381, -0.613144, 0.0797577, 2.2049, 1.50676, -1.1037, 0.928655, 0.65756, 0.90236, + 0.120245, 1.20524, 1.13427, 0.911429, 1.10395, 0.364186, -0.885258, -0.975241, -2.01043, -1.2072, 2.44035, + -0.415748, 0.770694, 1.08276, 0.116046, 2.36621, 0.330393, 0.369949, 1.08456, -0.508918, 0.317886, + -1.60069, -1.25545, 0.530332, 0.0345025, 0.189815, 1.48341, 0.359559, -0.156263, 0.74666, -0.785411, + -0.960866, -1.64142, 0.466532, 1.0859, 0.22569, -0.78543, -0.674478, 1.18015, 0.024879, 2.76743, 1.31827, + 1.59337, 1.21399, -1.33904, 1.13718, 1.71666, -0.101125 + ], + "dims": [2, 4, 4, 4, 3], + "type": "float32" + }, + { + "data": [0.241661, 0.960798, 0.474727], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + 
"type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, 0.162177, -0.248386, -0.225997, 0.711393, 0.558316, 0.118068, -0.362876, -0.25253, -0.124197, + 0.415637, 0.43437, -0.307105, -1.99282, 1.27847, -0.031878, -2.74067, -0.245533, -0.436439, 0.262807, + 0.2242, -0.0915989, 0.0511725, -0.220617, 0.16138, 1.25914, -0.441384, -0.277489, -0.739217, -0.29962, + 0.302606, 0.958488, 0.436609, 0.007258, 0.817189, -0.130032, 0.598301, -1.47782, 0.838785, 0.375807, + -0.00476569, 0.424829, 0.0931215, 0.894731, 0.111939, 0.354382, 0.0496784, 0.595496, -0.264267, -0.478311, + 0.407781, -0.619395, 0.265053, -0.0663221, 0.235864, 0.826283, 0.168311, -0.355328, 1.18895, 0.0953618, + 0.216252, 0.172641, 0.167699, 0.140402, 1.52901, 0.0343051, -0.271669, 0.711301, -0.74092, -0.0819684, + 0.766913, 0.492219, 0.321747, 2.04147, 0.823334, 0.10123, 1.67375, 0.0917494, -0.0728365, -0.309427, + -0.410451, -0.313746, -0.0108042, 0.400069, -0.202768, 0.196507, -0.176938, 0.0411454, -0.414344, + 0.277469, 0.279085, -1.25999, 0.0774512, -0.0618009, 0.178903, 0.762761, -0.142543, 0.749965, -0.0290194, + -0.1006, -1.37387, -0.59124, -0.230217, 1.5694, 0.6331, -0.0871152, -0.00806138, -0.456657, 0.0612094, + -0.179702, -0.0746837, 0.105652, 1.76844, -0.165507, 0.0773867, 1.26875, 0.171775, 0.0506533, -0.61199, + 0.350631, -0.0675486, 0.392911, 0.245554, -0.132074, -1.44932, 0.571595, 0.064109, -1.16025, -0.983996, + -0.259501, -0.124055, -0.484139, -0.400863, -0.0424004, -0.909715, 0.271191, 1.41773, 0.113902, 0.228758, + 1.03046, 0.0844908, 0.198777, 1.01167, 0.22054, 0.156327, -0.403818, 0.469418, -0.0707191, -1.08956, + 0.13511, -0.178816, -0.503507, -0.741041, -0.167941, 3.08178, -0.203725, 0.323581, -0.565845, 0.0527148, + -0.0157677, -0.506866, -0.0405267, 0.390985, 0.56872, -0.151521, 0.361745, -1.06196, -0.132817, -0.183938, + 0.500008, 0.247318, -0.0487849, 0.186433, -0.511752, -0.27895, -1.16476, -0.31818, -0.442507, -0.383928, + 0.0311666, -0.170435, -0.269921, 0.222426, -0.0345637, -0.349079, -0.391677, -0.031372, -0.00773219, + 0.17568, -0.193189, 1.37492, 0.801104, -0.282569, -1.22091, -0.822216, -0.0995112, -0.253445, -0.525139, + 0.39058, -1.7008, 0.451599, -0.235136, 0.195294, -0.338638, 0.639682, 1.29353, 0.767087, -0.0531687, + 1.07526, 0.932899, -0.347878, -1.96877, 0.00123571, -0.0045171, -0.772609, 0.0452554, 0.326198, 0.448463, + 0.450635, -0.230053, 2.09824, -0.838136, -0.41707, 1.4113, 0.176985, 0.320744, 0.863078, -0.370868, + 0.0679025, -0.623505, 0.714614, 0.00204798, 0.242314, 0.60984, 0.12394, 1.0809, 0.622438, 0.00913847, + -0.196543, 0.347455, 0.0335859, 0.119722, 0.73504, 0.0864127, 0.362425, 0.176161, -0.0204343, 0.859263, + 0.194278, 0.251436, -0.239351, -0.0507662, -0.434827, 0.681383, -0.851639, 0.314206, 0.960214, -0.765246, + 0.0943579, -1.3527, -0.383905, 0.295471, 0.186272, -0.524005, 0.231835, 0.631964, 0.563593, -0.233032, + -0.681905, -0.395179, -0.0226096, 1.05588, 0.0358864, 0.0698283, -0.140077, -0.520076, -0.0381934, + 0.748649, 0.455763, 0.128725, 0.84913, 0.684801, -0.121069, -1.2179, -0.093275, 0.060867, 1.11172, + -0.509682, -0.51808, 0.598526, 0.0785892, -0.153047, -0.356353, 0.299113, 0.301137, -0.099219, 0.743272, + -0.342009, -1.49835, -0.317872, 0.0881915, -0.452291, 0.550556, -0.00172722, 0.0964582, 0.0452602, + -0.147162, 0.405776, -0.0960432, 0.122883, -0.0658444, 0.385776, 0.24054, -0.417454, -0.151232, 0.292792, + -0.546667, -0.079013, -0.0408538, -0.246911, 0.42111, -0.234116, 1.93385, -0.0125717, 0.369206, -0.887121, + -0.35164, -0.0199082, 
-0.589104, 0.0378629, 0.532835, 1.44769, -0.523955, 0.224419, 0.631779, 0.428372, + 0.0290583, 1.15798, 0.538468, 0.220256, 1.06067, 0.172888, -0.213931, -0.937005, -0.954402, -0.291733, + 2.34467, -0.197366, 0.186246, 1.04031, 0.0550898, 0.571817, 0.31744, 0.175624, 0.262095, -0.488965, + 0.150908, -0.386822, -1.20623, 0.251761, 0.00833788, 0.182373, 0.704209, 0.086891, -0.150136, 0.354458, + -0.189802, -0.923194, -0.779221, 0.112742, 1.04332, 0.107141, -0.189807, -0.648034, 0.560244, 0.00601226, + 2.65893, 0.625814, 0.385054, 1.1664, -0.635675, 0.274811, 1.64935, -0.0480064 + ], + "dims": [2, 4, 4, 4, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization non-spatial mode", + "operator": "BatchNormalization", + "opset": { "domain": "", "version": 7 }, + "attributes": [{ "name": "spatial", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[3,1,2]", + "inputs": [ + { + "data": [0.2134, 0.32434, 0.5644, 0.3234, 0.4545, 0.3445], + "dims": [3, 1, 2], + "type": "float32" + }, + { + "data": [0.5, 0.6], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.2, 0.1], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.034, 0.342], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [1, 1], + "dims": [1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.2897, 0.089404, 0.4652, 0.08884, 0.41025, 0.1015], + "dims": [3, 1, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization non-spatial mode - NHWC", + "operator": "BatchNormalization", + "opset": { "domain": "com.ms.internal.nhwc", "version": 7 }, + "attributes": [{ "name": "spatial", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[3,2,1]", + "inputs": [ + { + "data": [0.2134, 0.32434, 0.5644, 0.3234, 0.4545, 0.3445], + "dims": [3, 2, 1], + "type": "float32" + }, + { + "data": [0.5, 0.6], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.2, 0.1], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.034, 0.342], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [1, 1], + "dims": [1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.2897, 0.089404, 0.4652, 0.08884, 0.41025, 0.1015], + "dims": [3, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 219e15eb4648f..2e8eaaba191d0 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -126,7 +126,7 @@ ] }, { - "name": "conv with bias addition C", + "name": "conv with bias addition C - NHWC", "operator": "Conv", "inputShapeDefinitions": "rankOnly", "opset": { "domain": "", "version": 17 }, @@ -158,6 +158,36 @@ "type": "float32" } ] + }, + { + "name": "inChannel = 3, outChannel = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [4, 3, 2, 2], + "type": "float32" + }, + { + "data": [5, 6, 7, 8], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [360, 334, 271, 323, 909, 963, 1024, 1028, 683, 655, 576, 650, 473, 508, 570, 677], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] } ] }, diff --git a/js/web/test/data/ops/cumsum.jsonc b/js/web/test/data/ops/cumsum.jsonc new file mode 100644 index 0000000000000..b3173afb695ea --- 
/dev/null +++ b/js/web/test/data/ops/cumsum.jsonc @@ -0,0 +1,1362 @@ +[ + { + "name": "CumSum", + "operator": "CumSum", + "attributes": [ + { "name": "exclusive", "data": 0, "type": "int" }, + { "name": "reverse", "data": 0, "type": "int" } + ], + "opset": { + "domain": "", + "version": 11 + }, + "cases": [ + { + "name": "CumSum 1-D; axis = 0; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 10, 15], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 1-D; axis = -1; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 10, 15], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = 0; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 5, 7, 9], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -1; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 4, 9, 15], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = 1; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 4, 9, 15], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -2; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 5, 7, 9], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 0; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 5, 7, 9, 12, 15, 18], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 1; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 4, 9, 15, 7, 15, 24], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 0; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 6, 8, 10, 12], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 1; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 4, 6, 5, 6, 12, 14], + "dims": [2, 2, 2], + 
"type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -1; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 3, 7, 5, 11, 7, 15], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 2; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 3, 7, 5, 11, 7, 15], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -2; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 4, 6, 5, 6, 12, 14], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -3; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-3], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 6, 8, 10, 12], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "CumSum", + "operator": "CumSum", + "attributes": [ + { "name": "exclusive", "data": 1, "type": "int" }, + { "name": "reverse", "data": 0, "type": "int" } + ], + "opset": { + "domain": "", + "version": 11 + }, + "cases": [ + { + "name": "CumSum 1-D; axis = 0; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 1, 3, 6, 10], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 1-D; axis = -1; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 1, 3, 6, 10], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = 0; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 0, 0, 1, 2, 3], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -1; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 1, 3, 0, 4, 9], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = 1; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 1, 3, 0, 4, 9], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -2", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 0, 0, 1, 2, 3], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 0; 
exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 0, 0, 1, 2, 3, 5, 7, 9], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 1; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 1, 3, 0, 4, 9, 0, 7, 15], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 0; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 0, 0, 0, 1, 2, 3, 4], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 1; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 2, 0, 0, 5, 6], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -1; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 3, 0, 5, 0, 7], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 2; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 3, 0, 5, 0, 7], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -2; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 2, 0, 0, 5, 6], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -3; exclusive = 1, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-3], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 0, 0, 0, 1, 2, 3, 4], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "CumSum", + "operator": "CumSum", + "attributes": [ + { "name": "exclusive", "data": 0, "type": "int" }, + { "name": "reverse", "data": 1, "type": "int" } + ], + "opset": { + "domain": "", + "version": 11 + }, + "cases": [ + { + "name": "CumSum 1-D; axis = 0; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [15, 14, 12, 9, 5], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 1-D; axis = -1; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [15, 14, 12, 9, 5], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); 
axis = 0; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [5, 7, 9, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -1; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [6, 5, 3, 15, 11, 6], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = 1; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [6, 5, 3, 15, 11, 6], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -2; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [5, 7, 9, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 0; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [12, 15, 18, 11, 13, 15, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 1; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [6, 5, 3, 15, 11, 6, 24, 17, 9], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 0; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [6, 8, 10, 12, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 1; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 6, 3, 4, 12, 14, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -1; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [3, 2, 7, 4, 11, 6, 15, 8], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 2; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [3, 2, 7, 4, 11, 6, 15, 8], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -2; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + 
{ + "data": [4, 6, 3, 4, 12, 14, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -3; exclusive = 0, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-3], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [6, 8, 10, 12, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "CumSum", + "operator": "CumSum", + "attributes": [ + { "name": "exclusive", "data": 1, "type": "int" }, + { "name": "reverse", "data": 1, "type": "int" } + ], + "opset": { + "domain": "", + "version": 11 + }, + "cases": [ + { + "name": "CumSum 1-D; axis = 0; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [14, 12, 9, 5, 0], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 1-D; axis = -1; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [14, 12, 9, 5, 0], + "dims": [5], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = 0; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 0, 0, 0], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -1; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [5, 3, 0, 11, 6, 0], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = 1; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [5, 3, 0, 11, 6, 0], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (2x3); axis = -2; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 0, 0, 0], + "dims": [2, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 0; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 13, 15, 7, 8, 9, 0, 0, 0], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 2-D (3x3); axis = 1; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [5, 3, 0, 11, 6, 0, 17, 9, 0], + "dims": [3, 3], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 0; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [5, 6, 7, 8, 0, 0, 
0, 0], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 1; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [3, 4, 0, 0, 7, 8, 0, 0], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -1; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [2, 0, 4, 0, 6, 0, 8, 0], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = 2; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [2, 0, 4, 0, 6, 0, 8, 0], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -2; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-2], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [3, 4, 0, 0, 7, 8, 0, 0], + "dims": [2, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "CumSum 3-D; axis = -3; exclusive = 1, reverse = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "float32" + }, + { + "data": [-3], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [5, 6, 7, 8, 0, 0, 0, 0], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "CumSum", + "operator": "CumSum", + "attributes": [ + { "name": "exclusive", "data": 0, "type": "int" }, + { "name": "reverse", "data": 0, "type": "int" } + ], + "opset": { + "domain": "", + "version": 11 + }, + "cases": [ + { + "name": "CumSum 5-D; axis = 4; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + }, + { + "data": [4], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 10, 15], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "CumSum", + "operator": "CumSum", + "attributes": [ + { "name": "exclusive", "data": 0, "type": "int" }, + { "name": "reverse", "data": 0, "type": "int" } + ], + "opset": { + "domain": "", + "version": 11 + }, + "cases": [ + { + "name": "CumSum int32; axis = 4; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [1, 1, 1, 1, 5], + "type": "int32" + }, + { + "data": [4], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 10, 15], + "dims": [1, 1, 1, 1, 5], + "type": "int32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/einsum.jsonc b/js/web/test/data/ops/einsum.jsonc index baf30cf982148..45bba6a121bd1 100644 --- a/js/web/test/data/ops/einsum.jsonc +++ b/js/web/test/data/ops/einsum.jsonc @@ -171,7 +171,7 @@ ], "cases": [ { - "name": "Diagonal elementwise multiplication", + "name": "Diagonal elements dot product", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], "dims": [3, 3], @@ -210,7 +210,7 @@ ], "cases": [ { - "name": "Dotproduct", + "name": "Diagonal elements multiplication", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], "dims": [3, 3], "type": "float32" @@ -233,6 +233,240 @@ } ] }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "",
"version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij,ij -> ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Elementwise multiplication", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 5, 0, 0, 0, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i", + "type": "string" + } + ], + "cases": [ + { + "name": "Dot product/scalar product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j->ij", + "type": "string" + } + ], + "cases": [ + { + "name": "outer product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 2, 4, 6, 3, 6, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij,ij -> ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Elementwise multiplication", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 5, 0, 0, 0, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i", + "type": "string" + } + ], + "cases": [ + { + "name": "Dot product/scalar product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j->ij", + "type": "string" + } + ], + "cases": [ + { + "name": "outer product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 2, 4, 6, 3, 6, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "einsum", "operator": "Einsum", @@ -249,7 +483,7 @@ ], "cases": [ { - "name": "Multiply", + "name": "Multiply (2,3) X (3,4) -> (2,4)", "inputs": [ { "data": [1, 2, 3, 4, 5, 6], @@ -269,6 +503,28 @@ "type": "float32" } ] + }, + { + "name": "Multiply (2,6) X (6,4) -> (2,4)", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [6, 4], + "type": "float32" + } + ], + 
"outputs": [ + { + "data": [220, 235, 250, 265, 580, 631, 682, 733], + "dims": [2, 4], + "type": "float32" + } + ] } ] }, @@ -631,5 +887,73 @@ ] } ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ijk->ikj", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose with 3 dims", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [1, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ij->...ji", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose with ellipsis with input/output dims > 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 1, 1, 2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [1, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 460122b4e085c..22bc04d558d98 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -85,5 +85,107 @@ ] } ] + }, + { + "name": "Expand 5D - float32", + "operator": "Expand", + "attributes": [], + "cases": [ + { + "name": "Expand 5 - float32", + "inputs": [ + { + "data": [1], + "dims": [1, 1, 1, 1, 1], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 6], + "dims": [5], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 1, 1, 1, 1, 1], + "dims": [1, 1, 1, 1, 6], + "type": "float32" + } + ] + }, + { + "name": "Expand 5 - shape < input.size()", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 1, 2, 6], + "type": "float32" + }, + { + "data": [2, 1, 6], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 2, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Expand - bool", + "operator": "Expand", + "attributes": [], + "cases": [ + { + "name": "Expand - last dim is divisible by 4", + "inputs": [ + { + "data": [true, false, false, true], + "dims": [4], + "type": "bool" + }, + { + "data": [2, 4], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [true, false, false, true, true, false, false, true], + "dims": [2, 4], + "type": "bool" + } + ] + }, + { + "name": "Expand - last dim is not divisible by 4", + "inputs": [ + { + "data": [true, false, false, true, true, true, false, false, false, true, true, true], + "dims": [2, 6], + "type": "bool" + }, + { + "data": [2, 1], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [true, false, false, true, true, true, false, false, false, true, true, true], + "dims": [2, 6], + "type": "bool" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/gather.jsonc b/js/web/test/data/ops/gather.jsonc index 3b1b0e3821832..0be077d237b88 100644 --- a/js/web/test/data/ops/gather.jsonc +++ b/js/web/test/data/ops/gather.jsonc @@ -93,5 +93,34 @@ ] } ] + }, + { + "name": "Gather - bool", + "operator": "Gather", + "attributes": [], + "cases": [ + { + "name": "data[2,4] indices[1]", + "inputs": [ + { + "data": [true, false, false, true, false, false, true, true], + "dims": [2, 4], + "type": "bool" + }, + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], 
+ "outputs": [ + { + "data": [false, false, true, true], + "dims": [1, 4], + "type": "bool" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/global-average-pool.jsonc b/js/web/test/data/ops/global-average-pool.jsonc index fdf3a8fe1e7a2..17aa061841b2c 100644 --- a/js/web/test/data/ops/global-average-pool.jsonc +++ b/js/web/test/data/ops/global-average-pool.jsonc @@ -61,6 +61,29 @@ "type": "float32" } ] + }, + { + "name": "T[1,3,2,2,2] T[1,3,1,1,1]", + "inputs": [ + { + "data": [ + 1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238, + -0.9772778749465942, 0.9500884413719177, -0.15135720372200012, -0.10321885347366333, 0.4105985164642334, + 0.14404356479644775, 1.4542734622955322, 0.7610377073287964, 0.12167501449584961, 0.44386324286460876, + 0.3336743414402008, 1.4940791130065918, -0.2051582634449005, 0.3130677044391632, -0.8540957570075989, + -2.5529897212982178, 0.653618574142456, 0.8644362092018127, -0.7421650290489197 + ], + "dims": [1, 3, 2, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.8841065168380737, 0.4457433819770813, -0.12865088880062103], + "dims": [1, 3, 1, 1, 1], + "type": "float32" + } + ] } ] } diff --git a/js/web/test/data/ops/multi-head-attention.jsonc b/js/web/test/data/ops/multi-head-attention.jsonc new file mode 100644 index 0000000000000..05687bd482e24 --- /dev/null +++ b/js/web/test/data/ops/multi-head-attention.jsonc @@ -0,0 +1,194 @@ +[ + { + "name": "MultiHeadAttention Basic, one head", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994, + 5.999990940093994, 6.999990940093994, 7.999990940093994 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549, + 5.998325824737549, 6.999900817871094, 7.999900817871094 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic with bias", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 
7, 8, 1, 2, 3, 4], + "dims": [12], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965, + 7.9997992515563965, 10, 11.999999046325684 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 9, 10, 11, 12, + 13, 14, 15, 16 + ], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 1, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 1, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/slice.jsonc b/js/web/test/data/ops/slice.jsonc index 9c90817a80c36..beef154a29932 100644 --- a/js/web/test/data/ops/slice.jsonc +++ b/js/web/test/data/ops/slice.jsonc @@ -21,6 +21,29 @@ } ] }, + { + "name": "Slice float32 with input[0] dim > 4", + "operator": "Slice", + "attributes": [], + "cases": [ + { + "name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)", + "inputs": [ + { + "data": [ + 0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692 + ], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + }, + { "data": [3], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" } + ], + "outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }] + } + ] + }, { "name": "Slice int32", "operator": "Slice", diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c80f0b04a9abc..a313adef7151b 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1336,6 +1336,8 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "attention.jsonc", + "batch-norm.jsonc", "bias-add.jsonc", "bias-split-gelu.jsonc", "ceil.jsonc", @@ -1362,6 +1364,7 @@ "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", + "multi-head-attention.jsonc", //"neg.jsonc", "neg-int32.jsonc", "not.jsonc", diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 24ab0694b32b8..9bd0ec1425f95 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -56,7 +56,7 @@ if (options.globalEnvFlags) { ort.env.wasm.initTimeout = flags.wasm.initTimeout; } if (flags.webgpu?.profilingMode !== 
undefined) { - ort.env.webgpu.profilingMode = flags.webgpu.profilingMode; + ort.env.webgpu.profiling = {mode: flags.webgpu.profilingMode}; } if (flags.webgpu?.validateInputContent !== undefined) { ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; } diff --git a/js/web/tsconfig.json b/js/web/tsconfig.json index d60d746e9328d..80d0cd0642b80 100644 --- a/js/web/tsconfig.json +++ b/js/web/tsconfig.json @@ -6,5 +6,5 @@ "typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types", "../node_modules/@types"] }, "include": ["lib", "test"], - "exclude": ["lib/wasm/proxy-worker"] + "exclude": ["lib/wasm/proxy-worker", "test/ort.test.js", "test/ort.test.min.js"] } diff --git a/objectivec/include/ort_env.h b/objectivec/include/ort_env.h index 8456b57bfa402..67db76668b3bb 100644 --- a/objectivec/include/ort_env.h +++ b/objectivec/include/ort_env.h @@ -24,6 +24,9 @@ NSString* _Nullable ORTVersion(void); /** * The ORT environment. + * It maintains shared state, including the default logger. + * + * @note An ORTEnv instance should be created before, and destroyed after, all other ORT API usage. */ @interface ORTEnv : NSObject diff --git a/objectivec/include/ort_training_session.h b/objectivec/include/ort_training_session.h index 15c0137817ae2..2ad4fed93c331 100644 --- a/objectivec/include/ort_training_session.h +++ b/objectivec/include/ort_training_session.h @@ -39,7 +39,7 @@ NS_ASSUME_NONNULL_BEGIN * session which will be moved to the device specified in the session option if needed. * * @param env The `ORTEnv` instance to use for the training session. - * @param sessionOptions The `ORTSessionOptions` to use for the training session. + * @param sessionOptions The optional `ORTSessionOptions` to use for the training session. * @param checkpoint Training states that are used as a starting point for training. * @param trainModelPath The path to the training ONNX model. * @param evalModelPath The path to the evaluation ONNX model. @@ -52,7 +52,7 @@ NS_ASSUME_NONNULL_BEGIN * keeps a strong (owning) pointer to the checkpoint state.
*/ - (nullable instancetype)initWithEnv:(ORTEnv*)env - sessionOptions:(ORTSessionOptions*)sessionOptions + sessionOptions:(nullable ORTSessionOptions*)sessionOptions checkpoint:(ORTCheckpoint*)checkpoint trainModelPath:(NSString*)trainModelPath evalModelPath:(nullable NSString*)evalModelPath diff --git a/objectivec/ort_session.mm b/objectivec/ort_session.mm index d27c3e2cefcfb..87288bd1e9dc7 100644 --- a/objectivec/ort_session.mm +++ b/objectivec/ort_session.mm @@ -23,6 +23,7 @@ NS_ASSUME_NONNULL_BEGIN @implementation ORTSession { + ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does std::optional _session; } @@ -44,6 +45,7 @@ - (nullable instancetype)initWithEnv:(ORTEnv*)env } } + _env = env; _session = Ort::Session{[env CXXAPIOrtEnv], path.UTF8String, [sessionOptions CXXAPIOrtSessionOptions]}; diff --git a/objectivec/ort_training_session.mm b/objectivec/ort_training_session.mm index 285151b412bf0..5387bfda6d411 100644 --- a/objectivec/ort_training_session.mm +++ b/objectivec/ort_training_session.mm @@ -19,8 +19,9 @@ NS_ASSUME_NONNULL_BEGIN @implementation ORTTrainingSession { - std::optional _session; + ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does ORTCheckpoint* _checkpoint; + std::optional _session; } - (Ort::TrainingSession&)CXXAPIOrtTrainingSession { @@ -28,7 +29,7 @@ @implementation ORTTrainingSession { } - (nullable instancetype)initWithEnv:(ORTEnv*)env - sessionOptions:(ORTSessionOptions*)sessionOptions + sessionOptions:(nullable ORTSessionOptions*)sessionOptions checkpoint:(ORTCheckpoint*)checkpoint trainModelPath:(NSString*)trainModelPath evalModelPath:(nullable NSString*)evalModelPath @@ -39,9 +40,17 @@ - (nullable instancetype)initWithEnv:(ORTEnv*)env } try { + if (!sessionOptions) { + sessionOptions = [[ORTSessionOptions alloc] initWithError:error]; + if (!sessionOptions) { + return nil; + } + } + std::optional evalPath = utils::toStdOptionalString(evalModelPath); std::optional optimizerPath = utils::toStdOptionalString(optimizerModelPath); + _env = env; _checkpoint = checkpoint; _session = Ort::TrainingSession{ [env CXXAPIOrtEnv], @@ -50,6 +59,7 @@ - (nullable instancetype)initWithEnv:(ORTEnv*)env trainModelPath.UTF8String, evalPath, optimizerPath}; + return self; } ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) diff --git a/objectivec/test/ort_session_test.mm b/objectivec/test/ort_session_test.mm index f00f5db2f995f..508289f7bc748 100644 --- a/objectivec/test/ort_session_test.mm +++ b/objectivec/test/ort_session_test.mm @@ -295,6 +295,32 @@ - (void)testStringInputs { XCTAssertTrue([stringData isEqualToArray:outputStringData]); } +- (void)testKeepORTEnvReference { + ORTEnv* __weak envWeak = _ortEnv; + // Remove sole strong reference to the ORTEnv created in setUp. + _ortEnv = nil; + // There should be no more strong references to it. + XCTAssertNil(envWeak); + + // Create a new ORTEnv. + NSError* err = nil; + ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning + error:&err]; + ORTAssertNullableResultSuccessful(env, err); + + ORTSession* session = [[ORTSession alloc] initWithEnv:env + modelPath:[ORTSessionTest getAddModelPath] + sessionOptions:[ORTSessionTest makeSessionOptions] + error:&err]; + ORTAssertNullableResultSuccessful(session, err); + + envWeak = env; + // Remove strong reference to the ORTEnv passed to the ORTSession initializer. + env = nil; + // ORTSession should keep a strong reference to it. 
+ XCTAssertNotNil(envWeak); +} + @end NS_ASSUME_NONNULL_END diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 0ed7d887fc5e5..57219c50f39aa 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -61,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401 -from onnxruntime.capi.training import * # noqa: F403 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end try: # noqa: SIM105 diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index b693b58c7c40a..a7f83469a768d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters { int kv_num_heads; int num_splits; // number of splits for splitkv bool is_unidirectional; // causal + int local_window_size; bool kv_share_buffer; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor - bool left_padding; // copies last token to last index if true + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 4a266af789250..47f462d75fcc4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -63,6 +63,16 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; const int half_head_size = head_size / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (parameters.transposed) { + // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -76,11 +86,10 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int s = static_cast((ptr / num_heads) % sequence_length); const int n = static_cast(ptr % num_heads); - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input_src + data_offset; - T* output_data = output_dest + data_offset; + const T* input_data = input_src + block_offset; + T* output_data = output_dest + block_offset; // Cache is (M, H/2) const int position_id = (position_ids_format == 0) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index cf8080800e072..7b2e8289f7b06 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -18,6 +18,7 @@ struct RotaryParameters { int num_heads; // num_heads = hidden_size / head_size int 
max_sequence_length; // Sequence length used by cos/sin cache int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -33,8 +34,8 @@ Status CheckInputs(const T* input, // Check input const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + if (input_dims.size() != 3 && input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", input_dims.size()); } // Check position_ids @@ -63,6 +64,14 @@ Status CheckInputs(const T* input, int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); int hidden_size = static_cast(input_dims[2]); + + bool transposed = false; + if (input_dims.size() == 4) { + // input is [batch, num_heads, seq, head_size] + sequence_length = static_cast(input_dims[2]); + hidden_size = static_cast(input_dims[1]) * static_cast(input_dims[3]); + transposed = true; + } int max_sequence_length = static_cast(cos_cache_dims[0]); int head_size = static_cast(cos_cache_dims[1]) * 2; int num_heads = hidden_size / head_size; @@ -111,6 +120,7 @@ Status CheckInputs(const T* input, output_parameters->num_heads = num_heads; output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; + output_parameters->transposed = transposed; } return Status::OK(); @@ -118,4 +128,4 @@ Status CheckInputs(const T* input, } // namespace rotary_embedding_helper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/image_scaler.h b/onnxruntime/contrib_ops/cpu/image_scaler.h index 9e9d9908ab188..865bca51f1e85 100644 --- a/onnxruntime/contrib_ops/cpu/image_scaler.h +++ b/onnxruntime/contrib_ops/cpu/image_scaler.h @@ -16,8 +16,8 @@ template class ImageScaler final : public OpKernel { public: ImageScaler(const OpKernelInfo& info) : OpKernel(info) { - ORT_ENFORCE(info.GetAttr("scale", &scale_).IsOK()); - ORT_ENFORCE(info.GetAttrs("bias", bias_).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttr("scale", &scale_)); + ORT_THROW_IF_ERROR(info.GetAttrs("bias", bias_)); } Status Compute(OpKernelContext* context) const override { diff --git a/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc b/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc index b00b10ad649b1..46a8b70d289b7 100644 --- a/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc @@ -47,7 +47,6 @@ struct ComputeCtx { float alpha; }; -#if !defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) template inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrixMap& map_A, const ConstEigenMatrixMapRowMajor& map_B, EigenMatrixMapRowMajor& output_map) { @@ -64,7 +63,8 @@ inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrix template <> inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrixMap& map_A, - const ConstEigenMatrixMapRowMajor& map_B, EigenMatrixMapRowMajor& output_map) { + const ConstEigenMatrixMapRowMajor& map_B, + EigenMatrixMapRowMajor& output_map) { if (ctx.trans_A && ctx.trans_B) { output_map = map_A.transpose() * 
ctx.alpha * map_B.transpose(); } else if (ctx.trans_A && !ctx.trans_B) { @@ -84,21 +84,47 @@ struct SparseToDenseCsr { const auto& b_dims = B.Shape().GetDims(); const auto& out_dims = output.Shape().GetDims(); auto csr_view = A.AsCsr(); - - ConstSparseMatrixMap map_A(a_dims[0], a_dims[1], A.NumValues(), - csr_view.Outer().Data(), - csr_view.Inner().Data(), + const Eigen::Index* inner_index_pointer = nullptr; + const Eigen::Index* outer_index_pointer = nullptr; + // For auto-release the above two pointers when they are not NULL. + std::unique_ptr buffer_holder_inner, buffer_holder_outer; + if constexpr (std::is_integral::value && + std::is_signed::value && + (sizeof(Eigen::Index) == sizeof(int64_t))) { + // On macOS the following reinterpret_cast is necessary because Eigen::Index is an alias of `long` but int64_t is + // `long long`. Though they have the same size, compilers still do not allow an implicit casting between them. + inner_index_pointer = reinterpret_cast(csr_view.Inner().Data()); + outer_index_pointer = reinterpret_cast(csr_view.Outer().Data()); + } else { + // In a 32-bit build we need to cast the following two tensors to 32 bits + gsl::span inner_data = csr_view.Inner().DataAsSpan(); + gsl::span outer_data = csr_view.Outer().DataAsSpan(); + buffer_holder_inner.reset(new Eigen::Index[inner_data.size()]); + buffer_holder_outer.reset(new Eigen::Index[outer_data.size()]); + inner_index_pointer = buffer_holder_inner.get(); + outer_index_pointer = buffer_holder_outer.get(); + + std::transform(inner_data.begin(), inner_data.end(), + buffer_holder_inner.get(), [](int64_t v) -> Eigen::Index { + return narrow(v); + }); + std::transform(outer_data.begin(), outer_data.end(), + buffer_holder_outer.get(), [](int64_t v) -> Eigen::Index { + return narrow(v); + }); + } + ConstSparseMatrixMap map_A(narrow(a_dims[0]), narrow(a_dims[1]), + narrow(A.NumValues()), outer_index_pointer, inner_index_pointer, A.Values().Data()); - ConstEigenMatrixMapRowMajor map_B(B.Data(), b_dims[0], b_dims[1]); - EigenMatrixMapRowMajor output_map(output.MutableData(), out_dims[0], out_dims[1]); + ConstEigenMatrixMapRowMajor map_B(B.Data(), narrow(b_dims[0]), narrow(b_dims[1])); + EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]), + narrow(out_dims[1])); // XXX: Consider re-writing it as a parallel loop as Eigen requires it to use OpenMP // XXX: Consider vectorization SparseDenseMatMulImpl(ctx, map_A, map_B, output_map); } }; -#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) - template inline T Mul(T a_value, float, T b_value) { return a_value * b_value; @@ -121,9 +147,11 @@ struct SparseToDenseCoo { auto coo_view = A.AsCoo(); const auto& ind_dims = coo_view.Indices().Shape().GetDims(); ORT_RETURN_IF_NOT(ind_dims.size() == 2, "COO indices must be 2-D, got: ", ind_dims.size()); - ConstEigenMatrixMapRowMajor a_indicies_map(coo_view.Indices().Data(), narrow(ind_dims[0]), narrow(ind_dims[1])); + ConstEigenMatrixMapRowMajor a_indicies_map(coo_view.Indices().Data(), narrow(ind_dims[0]), + narrow(ind_dims[1])); ConstEigenMatrixMapRowMajor map_b(B.Data(), narrow(b_dims[0]), narrow(b_dims[1])); - EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]), narrow(out_dims[1])); + EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]), + narrow(out_dims[1])); output_map.setZero(); const auto rhs_right = (ctx.trans_B) ? 
b_dims[0] : b_dims[1]; @@ -140,7 +168,8 @@ struct SparseToDenseCoo { ORT_RETURN_IF_NOT(m < out_left, "COO m index: ", m, " is out of bounds of out_left: ", out_left); const T a_value = a_values[i]; for (int64_t n = 0; n < rhs_right; ++n) { - const T b_value = (ctx.trans_B) ? map_b(narrow(n), narrow(k)) : map_b(narrow(k), narrow(n)); + const T b_value = + (ctx.trans_B) ? map_b(narrow(n), narrow(k)) : map_b(narrow(k), narrow(n)); output_map(narrow(m), narrow(n)) += Mul(a_value, ctx.alpha, b_value); } } @@ -170,8 +199,9 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { const auto inner_B = (trans_b_attr_) ? b_dims[1] : b_dims[0]; const auto outer_B = (trans_b_attr_) ? b_dims[0] : b_dims[1]; - ORT_RETURN_IF_NOT(inner_A == inner_B, "Can not multiply A and B as inner dimension does not match. inner_A: ", - inner_A, " vs inner_B: ", inner_B); + ORT_RETURN_IF_NOT(inner_A == inner_B, + "Can not multiply A and B as inner dimension does not match. inner_A: ", inner_A, + " vs inner_B: ", inner_B); TensorShape output_shape{outer_A, outer_B}; auto* output = ctx->Output(0, output_shape); @@ -184,12 +214,10 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { auto coo_view = A->AsCoo(); const auto num_dims = coo_view.Indices().Shape().NumDimensions(); ORT_RETURN_IF_NOT(num_dims == 2, "Expecting COO 2-D indices shape"); - ORT_RETURN_IF_NOT(A->Values().Shape().Size() * 2 == coo_view.Indices().Shape().Size(), "Expecting 2xValues == indices"); + ORT_RETURN_IF_NOT(A->Values().Shape().Size() * 2 == coo_view.Indices().Shape().Size(), + "Expecting 2xValues == indices"); auto status = t_disp.InvokeRet(compute_ctx, *A, *B, *output); ORT_RETURN_IF_ERROR(status); -// Eigen has a bug in x86 where it calculates reallocation size as -1 -// and throws bad_alloc -#if !defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) } else if (A->Format() == SparseFormat::kCsrc) { auto csr_view = A->AsCsr(); ORT_RETURN_IF_NOT(A->Values().Shape().Size() == csr_view.Inner().Shape().Size(), @@ -199,11 +227,6 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Currently support only COO and CSR(x64) formats"); } -#else - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "WASM and 32-bit builds support only COO format"); - } -#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__) return Status::OK(); } @@ -211,4 +234,4 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const { } // namespace contrib } // namespace onnxruntime -#endif //! defined(DISABLE_SPARSE_TENSORS) \ No newline at end of file +#endif //! 
defined(DISABLE_SPARSE_TENSORS) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 320a05bb97dac..b060d500c6484 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -20,30 +20,158 @@ class MatMulNBits final : public OpKernel { K_{narrow(info.GetAttr("K"))}, N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, - nbits_{narrow(info.GetAttr("bits"))} { + nbits_{narrow(info.GetAttr("bits"))}, + accuracy_level_{info.GetAttr("accuracy_level")} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + is_asym_ = info.GetInputCount() >= 4; + const Tensor* tensor_B = nullptr; + const Tensor* tensor_scale = nullptr; + const Tensor* tensor_zero_point = nullptr; + bool B_constant = info.TryGetConstantInput(1, &tensor_B); + bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); + bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + all_constant_ = B_constant && scale_constant; + all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; } Status Compute(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + /*out*/ bool& used_shared_buffers) override; + private: const size_t K_; const size_t N_; const size_t block_size_; const size_t nbits_; + const int64_t accuracy_level_; const bool column_wise_quant_{true}; + IAllocatorUniquePtr packed_b_; + size_t packed_b_size_{0}; + bool is_asym_{false}; + bool all_constant_{false}; }; +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (!all_constant_) { + return Status::OK(); + } + auto compt_type = static_cast(accuracy_level_); + MLAS_THREADPOOL* pool = NULL; + if (input_idx == 1) { + packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type); + if (packed_b_size_ == 0) return Status::OK(); + auto qptr = tensor.Data(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + if (packed_b_ == nullptr) { + return Status::OK(); + } + std::memset(packed_b_.get(), 0, packed_b_size_); + MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), + is_asym_, false, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 2 && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), + is_asym_, !is_asym_, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 3 && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_), + is_asym_, is_asym_, compt_type, pool); + 
if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + + return Status::OK(); +} + +Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + // Pack three tensors into one buffer + if (input_idx == 1) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 2) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 3) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + return Status::OK(); +} + Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(0); + const auto* a_data = a->Data(); + + if (packed_b_.get()) { + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + std::vector gemm_params(max_len); + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + for (size_t i = 0; i < max_len; i++) { + gemm_params[i].A = a_data + helper.LeftOffsets()[i]; + gemm_params[i].lda = lda; + gemm_params[i].B = packed_b_.get(); + gemm_params[i].C = y_data + helper.OutputOffsets()[i]; + gemm_params[i].ldc = N; + } + auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + // workspace for activation process(dynamic quantization and others) + auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); + MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), + thread_pool); + return Status::OK(); + } + const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); - - const auto* a_data = a->Data(); const uint8_t* b_data = b->Data(); const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 89e2351428d40..cbe536c6ce45a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params { int seqlen_q_rounded = 0; int seqlen_k_rounded = 0; int d_rounded = 0; + int rotary_dim = 0; // The scaling factors for the kernel. float scale_softmax = 0.0; @@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride = 0; index_t vnew_head_stride = 0; + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. 
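The prepacked Compute path in the MatMulNBits hunk above hands the batched GEMM to the MLAS n-bit kernels (MlasNBitsGemmPackB / MlasSQNBitsGemmBatchPackedB), which makes the underlying arithmetic easy to lose track of. Below is a scalar reference of what the op computes, as a sketch only: the real MLAS packed layout differs, and this assumes B stored row-major with two 4-bit codes per byte, low nibble first, one scale per block_size run of K, and a default zero point of 8 when none is supplied.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference: Y[m][n] = sum_k A[m][k] * scale[n][k/bs] * (q[n][k] - zp[n][k/bs]),
// i.e. A (M x K, float) times the dequantized, transposed B (N x K, 4-bit).
std::vector<float> MatMulNBitsRef(const std::vector<float>& A,
                                  const std::vector<uint8_t>& B_quant,  // N x K/2 bytes, two nibbles each
                                  const std::vector<float>& scales,     // N x (K / block_size)
                                  const uint8_t* zero_points,           // optional, same count as scales
                                  size_t M, size_t N, size_t K, size_t block_size) {
  std::vector<float> Y(M * N, 0.0f);
  const size_t blocks_per_row = K / block_size;
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        const uint8_t byte = B_quant[n * (K / 2) + k / 2];
        const uint8_t q = (k % 2 == 0) ? (byte & 0x0F) : (byte >> 4);  // low nibble first (assumed)
        const size_t blk = k / block_size;
        const float scale = scales[n * blocks_per_row + blk];
        const float zp = zero_points ? float(zero_points[n * blocks_per_row + blk]) : 8.0f;
        acc += A[m * K + k] * scale * (float(q) - zp);
      }
      Y[m * N + n] = acc;
    }
  }
  return Y;
}

The accuracy_level attribute read in the constructor is forwarded to MLAS as the compute type (compt_type above), i.e. how much of this work may run in reduced precision; the reference here corresponds to full fp32 accumulation.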
+ int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + bool is_bf16 = false; bool is_causal = false; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + int num_splits = 0; // For split-KV version const cudaDeviceProp* dprops = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 89a27c4d2b0d3..76190aad68fdb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -35,7 +35,9 @@ void set_params_fprop(Flash_fwd_params& params, void* softmax_lse_d, float softmax_scale, bool is_causal, - bool kv_bsnh = true) { + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -102,7 +104,21 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates + // local and causal, meaning when we have a local window size we disable the causal flag and let the window masks enforce causality. params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.is_seqlens_k_cumulative = true; } @@ -227,7 +243,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh) { + bool kv_bsnh, + int local_window_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -247,7 +264,9 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - kv_bsnh); + kv_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; @@ -306,7 +325,10 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + true, + -1, + is_causal ?
0 : -1); params.dprops = &dprops; params.num_splits = 0; params.softmax_lseaccum_ptr = nullptr; @@ -347,11 +369,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -) { - if (seqlen_q == 1) { - is_causal = false; - } // causal=true is the same as causal=false in this case + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size) { + // if (seqlen_q == 1) { + // is_causal = false; + // } // causal=true is the same as causal=false in this case auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); @@ -372,7 +394,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - past_bsnh); + past_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; if (k != nullptr && v != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 58f4304251872..efc1f565c4fa0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -54,7 +54,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh = true); + bool kv_bsnh = true, + int local_window_size = -1); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -96,8 +97,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -); + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index eb1c794d6df54..028233f66850f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -29,47 +29,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - make_layout(cute::size<2>(TileShape_MNK{}))); - - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - 
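The window plumbing that set_params_fprop gains above is easy to misread: causal attention with a finite window is re-expressed as a pure local mask, and a one-sided window is closed off with seqlen_k on the open side. Here is a host-side sketch of just that normalization (NormalizeWindow is a hypothetical name; the branches mirror the hunk above):

struct WindowParams { bool is_causal; int left; int right; };

WindowParams NormalizeWindow(bool is_causal, int left, int right, int seqlen_k) {
  if (is_causal && (left >= 0 || right != 0)) is_causal = false;  // window takes over from the causal flag
  if (left < 0 && right >= 0) left = seqlen_k;   // only a right bound was given
  if (left >= 0 && right < 0) right = seqlen_k;  // only a left bound was given
  return {is_causal, left, right};
}

// Example: the callers in this file pass (local_window_size, is_causal ? 0 : -1),
// so a causal call with local_window_size = 256 yields {false, 256, 0}: a
// 256-token lookback window whose right edge still enforces causality.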
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - // TODO: Shouldn't this be size<1>? - make_layout(cute::size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, Tensor2& acc_o, float softmax_scale_log2) { @@ -123,7 +82,7 @@ inline __device__ void write_softmax_to_gmem( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -144,12 +103,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // We exit early and write 0 to gO and gLSE. // Otherwise we might read OOB elements from gK and gV. 
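With Is_local, the hunk above also raises n_block_min, so each query row-block only visits the K/V blocks its window can reach instead of starting from block 0. The same arithmetic as a standalone host-side check (block sizes and sequence lengths are made-up example values):

#include <algorithm>
#include <cstdio>

int main() {
  const int kBlockM = 128, kBlockN = 64;
  const int seqlen_q = 1024, seqlen_k = 4096;
  const int window_size_left = 256, window_size_right = 0;  // causal with 256-token lookback
  const int m_blocks[] = {0, 7};
  for (int m_block : m_blocks) {
    // n_block_min/n_block_max as computed in compute_attn_1rowblock (ceil_div == (a + b - 1) / b).
    const int n_block_min = std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) / kBlockN);
    int n_block_max = (seqlen_k + kBlockN - 1) / kBlockN;
    n_block_max = std::min(n_block_max,
                           ((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_size_right + kBlockN - 1) / kBlockN);
    std::printf("m_block=%d visits K/V blocks [%d, %d)\n", m_block, n_block_min, n_block_max);
  }
  return 0;
}

For m_block=0 this prints [44, 50): six 64-wide blocks around the diagonal instead of all 64, which is where the speedup of local attention comes from.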
- if (n_block_max <= 0) { + if (n_block_max <= n_block_min) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), @@ -197,7 +158,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), cute::Shape, cute::Int>{}, make_stride(params.q_row_stride, _1{})); @@ -332,9 +292,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -364,22 +324,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -390,8 +350,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? 
softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 cute::Tensor rP = flash::convert_type(scores); @@ -408,14 +368,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); @@ -431,7 +391,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -441,8 +401,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); cute::Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -543,7 +510,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -572,11 +539,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons if (m_block * kBlockM >= binfo.actual_seqlen_q) return; const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = n_split_idx * n_blocks_per_split; + const int n_block_min = !Is_local + ? 
n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal) { + if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. @@ -626,10 +595,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -641,16 +609,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. 
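The bidb_cache indirection introduced above decouples the query batch index from the KV-cache slot, so callers can reorder or share cache entries (e.g. pointing several beams at one prompt cache) without copying K/V. A minimal sketch of the idea; the struct and layout here are illustrative, not the kernel's:

#include <cstdint>

struct KvCacheView {
  const float* k_cache;        // [cache_batches, seqlen_k, head_size], illustrative layout
  const int* cache_batch_idx;  // optional, length = query batch size
  int64_t k_batch_stride;      // elements between consecutive cache batches
};

// Mirrors bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb].
const float* KBatchPtr(const KvCacheView& v, int bidb) {
  const int bidb_cache = v.cache_batch_idx == nullptr ? bidb : v.cache_batch_idx[bidb];
  return v.k_cache + bidb_cache * v.k_batch_stride;
}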
- Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); @@ -664,11 +622,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); typename Kernel_traits::TiledMma tiled_mma; @@ -732,17 +688,129 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. 
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. 
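When rotary_dim > 0, this Append_KV prologue rotates the appended keys as it copies them into the cache, in either interleaved or contiguous (NeoX-style) pair layout, which is what copy_rotary_interleaved / copy_rotary_contiguous distinguish. A scalar reference for the two layouts, assuming rotary_dim is even and cos_t / sin_t hold rotary_dim / 2 values precomputed for the token's absolute position:

// Interleaved: rotate adjacent pairs (x[2i], x[2i+1]).
void RotaryInterleaved(float* x, const float* cos_t, const float* sin_t, int rotary_dim) {
  for (int i = 0; i < rotary_dim / 2; ++i) {
    const float a = x[2 * i], b = x[2 * i + 1];
    x[2 * i]     = a * cos_t[i] - b * sin_t[i];
    x[2 * i + 1] = a * sin_t[i] + b * cos_t[i];
  }
}

// Contiguous (NeoX style): rotate split halves (x[i], x[i + rotary_dim/2]).
void RotaryContiguous(float* x, const float* cos_t, const float* sin_t, int rotary_dim) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < half; ++i) {
    const float a = x[i], b = x[i + half];
    x[i]        = a * cos_t[i] - b * sin_t[i];
    x[i + half] = a * sin_t[i] + b * cos_t[i];
  }
}

Dimensions past rotary_dim, up to the head size, pass through unchanged, which is why the copy helpers take both params.d and params.rotary_dim.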
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + // Read Q from gmem to smem, optionally apply rotary embedding. Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? 
m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // flash::cp_async_wait<0>(); @@ -760,9 +828,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ?
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -770,32 +838,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (cute::thread0()) { print(tKgK); } - // if (cute::thread0()) { print(tKsK); } - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - // __syncthreads(); - // if (cute::thread0()) { print(tKgK); } - // __syncthreads(); - } - // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); @@ -810,15 +860,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. 
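apply_mask_local, which replaces apply_mask_causal here and in the non-split kernel, keeps a score only when its column falls inside a per-row window. A predicate-level sketch of the limits it computes (they match the softmax.h hunk later in this diff):

#include <algorithm>

// True means "set this score to -inf". row/col index the seqlen_q x seqlen_k score matrix.
bool MaskedOut(int row, int col, int seqlen_q, int seqlen_k,
               int window_size_left, int window_size_right, bool has_left_window) {
  const int col_limit_left = std::max(0, row + seqlen_k - seqlen_q - window_size_left);
  const int col_limit_right = std::min(seqlen_k, row + 1 + seqlen_k - seqlen_q + window_size_right);
  return col >= col_limit_right || (has_left_window && col < col_limit_left);
}

// Causal masking is the special case has_left_window = false (window_size_left = infinity)
// with window_size_right = 0, which is exactly how the new apply_mask_causal forwards
// to apply_mask_local.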
- if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); @@ -826,26 +876,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); } - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } - if (n_block > n_block_min) { // Advance gK - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); } tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -853,8 +887,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? 
softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert scores from fp32 to fp16/bf16 @@ -879,20 +913,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -901,22 +924,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
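softmax_rescale_o, invoked above for both the masking and non-masking steps, implements the streaming-softmax update at the core of flash attention: when a new block of scores raises a row's running max, the running sum and the output accumulator are rescaled by exp(m_old - m_new) before the block is folded in. A one-row scalar sketch (the kernel additionally special-cases the first step and all-masked rows via Is_first / Check_inf):

#include <algorithm>
#include <cmath>
#include <vector>

struct RowState { float m = -INFINITY; float s = 0.0f; std::vector<float> acc_o; };

void RescaleAndAccumulate(RowState& st, const std::vector<float>& scores,
                          const std::vector<std::vector<float>>& v_block) {
  float block_max = -INFINITY;
  for (float x : scores) block_max = std::max(block_max, x);
  const float m_new = std::max(st.m, block_max);
  const float correction = std::exp(st.m - m_new);  // 0 on the first block, 1 if the max is unchanged
  st.s *= correction;
  for (float& o : st.acc_o) o *= correction;
  for (size_t j = 0; j < scores.size(); ++j) {
    const float p = std::exp(scores[j] - m_new);
    st.s += p;
    for (size_t d = 0; d < st.acc_o.size(); ++d) st.acc_o[d] += p * v_block[j][d];
  }
  st.m = m_new;
}

// After the last K/V block, the row's output is acc_o / s and its logsumexp is m + log(s).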
cute::cp_async_fence(); @@ -924,7 +935,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -1031,7 +1049,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1047,12 +1065,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1061,24 +1079,23 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params& params) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kBlockM = 16; constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); - static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); - static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); // Shared memory. 
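combine_attn_seqk_parallel merges the num_splits partial results of the split-KV kernel: each split wrote a partial output plus the logsumexp of the scores it saw, and the exact full-softmax output is their lse-weighted average. A scalar sketch of the combination:

#include <algorithm>
#include <cmath>
#include <vector>

// lse[i], o[i]: logsumexp and partial output of split i for one (batch, head, query) row.
// Returns the combined logsumexp and writes the exact full-attention output to `out`.
float CombineSplits(const std::vector<float>& lse,
                    const std::vector<std::vector<float>>& o,
                    std::vector<float>& out) {
  const float m = *std::max_element(lse.begin(), lse.end());
  float z = 0.0f;
  for (float l : lse) z += std::exp(l - m);  // stable sum of exp(lse[i])
  const float lse_total = m + std::log(z);
  out.assign(o[0].size(), 0.0f);
  for (size_t i = 0; i < lse.size(); ++i) {
    const float w = std::exp(lse[i] - lse_total);  // this split's share of the softmax mass
    for (size_t d = 0; d < out.size(); ++d) out[d] += w * o[i][d];
  }
  return lse_total;
}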
// kBlockM + 1 instead of kBlockM to reduce bank conflicts. @@ -1094,10 +1111,10 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { make_stride(params.b * params.h * params.seqlen_q, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. - constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; @@ -1165,7 +1182,12 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape, Int>{}, Stride, _1>{}); - typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); Tensor tOrO = make_tensor(shape(tOgOaccum)); @@ -1183,8 +1205,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } } -// Load Oaccum in then scale and accumulate to O -#pragma unroll 2 + // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { flash::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 82dfa59b8f8e7..87d189a803f8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -10,29 +10,30 @@ namespace onnxruntime { namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); + flash::compute_attn(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + flash::combine_attn_seqk_parallel(params); #else (void)params; #endif @@ -52,20 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const bool is_even_K = params.d == 
Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); }); }); } @@ -82,40 +88,46 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. 
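The added BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, ...) level turns the runtime window check into a compile-time template parameter. Each nested switch doubles the number of kernel instantiations, which is why these hunks also pin IsEvenMNConst to false when Is_local is set or the head dim is large. The macro follows the usual static-switch pattern, roughly as below (a sketch; the real definition lives in a static_switch.h header alongside these files):

// A runtime bool selects one of two lambda instantiations, each seeing a constexpr flag.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr static bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()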
+ // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); }); }); }); }); }); if (params.num_splits > 1) { - dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } }); } @@ -130,7 +142,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_causal>(params, stream); }); @@ -138,7 +150,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k @@ -174,8 +186,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 128; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
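The combine grid above now shrinks its row block as the head dim grows: with 128 threads moving 512 elements per pass, head dims divisible by 128 can use 4-row blocks and those divisible by 64 can use 8-row blocks (4 x 128 = 8 x 64 = 512), with 16 as the fallback, so more rows are combined in parallel for small heads. The selection rule as compile-time checks:

constexpr int CombineBlockM(int head_dim) {
  return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
}
static_assert(CombineBlockM(128) == 4 && CombineBlockM(192) == 8 && CombineBlockM(96) == 16, "");
static_assert(4 * 128 == 512 && 8 * 64 == 512, "rows x head dim per 512-element pass");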
@@ -201,8 +213,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 160; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -241,12 +253,11 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 224; - constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= threshold) { // 112 KB + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); @@ -262,16 +273,14 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 256; - constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); - constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + constexpr static int Headdim = 256; size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 134f159e258c4..1c0ed7f2fc2e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -161,7 +161,14 @@ struct Flash_fwd_kernel_traits : public Base { cute::Stride<_16, _1>>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, - cute::Layout>{})); // Val layout, 4 vals per store + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. 
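The inlined shared-memory thresholds in the hdim224 / hdim256 launchers above are plain byte counts for 2-byte (fp16/bf16) elements: a kBlockM-row Q tile plus kBlockN-row K and V tiles. Spelled out as compile-time checks (the factor 4 in the per-SM bound is 2 bytes times the two CTAs the comment wants resident):

// smem for one CTA: elements are 2 bytes; Q is kBlockM x head_dim, K and V are kBlockN x head_dim.
constexpr int SmemBytes(int head_dim, int block_m, int block_n) {
  return 2 * head_dim * (block_m + 2 * block_n);
}
static_assert(SmemBytes(224, 128, 64) == 112 * 1024, "the 112 KB the comment mentions");
static_assert(SmemBytes(256, 128, 64) == 128 * 1024, "A100 config for hdim256");
static_assert(4 * 256 * (64 + 2 * 64) == 2 * SmemBytes(256, 64, 64), "per-SM bound = two 96 KB CTAs");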
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 842edf3a98a86..8017f83bbb01d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor& tensor, const int max_ } } -template -inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i } } +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor& tensor, Tensor const& idx_rowcol, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 02042e183f808..271112c5e890a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -307,7 +307,7 @@ template inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, int max_MN = 0) { + Tensor const& predicate_K, const int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -334,65 +334,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0, - Tensor const& S1, +inline __device__ void 
copy_w_min_idx(Tensor const& S, Tensor& D, Tensor const& identity_MN, Tensor const& predicate_K, - const int max_MN = 0, const int row_idx_switch = 0) { - CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); -// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } #pragma unroll - for (int m = 0; m < size<1>(S0); ++m) { - auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1; - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll - for (int k = 0; k < size<2>(S0); ++k) { + for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) 
>= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Unclear why, but a copy is needed for convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_w_min_idx(Tensor const& S, - Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, - const int max_MN = 0, const int min_MN = 0) { +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { #pragma unroll for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(S(_, m, k), D(_, m, k)); + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if 
(get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Unclear why, but a copy is needed for convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index f21dff08e0350..93892169f6c79 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -44,9 +44,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); num_heads_ = static_cast(num_heads); kv_num_heads_ = static_cast(kv_num_heads); - is_unidirectional_ = true; - // left_padding_ = info.GetAttrOrDefault("left_padding_last_token", 0) == 1; is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -92,8 +91,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { is_past_bsnh_, scale_, device_prop.maxThreadsPerBlock)); - parameters.is_unidirectional = is_unidirectional_; - // parameters.left_padding = left_padding_; + parameters.local_window_size = local_window_size_; int sequence_length = parameters.sequence_length; TensorShapeVector output_shape(3); @@ -139,6 +137,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { bool use_memory_efficient_attention = !use_flash_attention && !disable_memory_efficient_attention_ && + local_window_size_ == -1 && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -222,6 +221,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v 
= reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index aade0436dc141..54a8127e29e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -22,8 +22,7 @@ class GroupQueryAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention - // bool left_padding_; // shifts last token to end of buffer - bool is_unidirectional_; // causal + int local_window_size_; bool is_past_bsnh_; float scale_; bool disable_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 2d158155eeba9..b22ccb68c1e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -468,55 +468,6 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } -// // Kernel to append new kv to kv buffer in place -// template -// __global__ void LeftPadLast(const int max_seqlen, -// T* kv_buff, -// const int* seqlens_k) { // refers to kv buff; otherwise bnsh -// const int h = threadIdx.x; -// const int n = blockIdx.x; -// const int b = blockIdx.y; - -// const int num_heads = gridDim.x; -// const int H = blockDim.x; - -// const int present_batch_stride = max_seqlen * num_heads * H; -// const int present_row_stride = num_heads * H; -// const int present_head_stride = H; - -// // kv_buff: BTNH or BNTH with buffered memory for new -// // new_kv: BLNH - -// const int s = seqlens_k[b]; - -// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h; -// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h; -// kv_buff[out_offset] = kv_buff[in_offset]; -// } - -// // Concat new to kv buffer in place -// template -// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters, -// GroupQueryAttentionData& data, -// cudaStream_t stream, -// const int max_threads_per_block) { -// const int batch_size = parameters.batch_size; -// const int sequence_length = parameters.sequence_length; -// const int num_heads = parameters.num_heads; -// const int head_size = parameters.head_size; - -// // Indicates past sequence_length of each sequence -// const int* seqlens_k = reinterpret_cast(data.seqlens_k); - -// const int H = head_size / 4; -// const dim3 grid(num_heads, batch_size, 1); -// const dim3 block(H, 1, 1); -// LeftPadLast<<>>(sequence_length, -// reinterpret_cast(data.output), -// seqlens_k); -// return CUDA_CALL(cudaGetLastError()); -// } - ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -541,7 +492,7 @@ Status FlashAttention( void* key = reinterpret_cast(const_cast(data.key)); void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = parameters.is_unidirectional; + bool is_causal = true; // Note: seqlens_k is past sequence length for flash if (parameters.is_prompt) { @@ -579,7 +530,7 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, 
present_sequence_length, kv_sequence_length, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } else { // Non-shared buffer case // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inefficient @@ -611,13 +562,9 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, 0, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -704,9 +651,11 @@ Status EfficientAttention( p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; - p.causal = parameters.is_unidirectional; + p.causal = true; p.scale = scale; p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; p.query = query; p.key = key; p.value = value; @@ -721,10 +670,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index b4b5dac1fbe19..2d12e975d88d7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -74,7 +74,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.max_sequence_length, parameters.position_ids_format, interleaved, - device_prop.maxThreadsPerBlock); + device_prop.maxThreadsPerBlock, + parameters.transposed); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c54e72dcfce13..e1b83bd8caf54 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -27,7 +27,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int num_heads, const int head_size, const int position_ids_format, - const bool interleaved) { + const bool interleaved, + const int batch_stride, + const int seq_stride, + const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -37,11 +40,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int i = threadIdx.x; - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input + 
data_offset; - T* output_data = output + data_offset; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; // Cache is (M, H/2) const int half_head_size = head_size / 2; @@ -83,7 +85,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool transposed) { constexpr int smem_size = 0; const dim3 grid(num_heads, sequence_length, batch_size); @@ -94,10 +97,22 @@ Status LaunchRotaryEmbeddingKernel( // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. + // Default input tensor shape is [batch, seq, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (transposed) { + // When transposed, input tensor shape is [batch, num_heads, seq, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } + assert(head_size <= max_threads_per_block); RotaryEmbeddingBSNH<<>>( output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved + sequence_length, num_heads, head_size, position_ids_format, interleaved, + batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -117,7 +132,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); template Status LaunchRotaryEmbeddingKernel( cudaStream_t stream, @@ -133,7 +149,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 29ff48a8ad0fb..ee1ccc43dcbff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -24,7 +24,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h b/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h index faf9310c4c3fd..a0da24210459c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h +++ b/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h @@ -3,7 +3,7 @@ #pragma once -#include "core/providers/cuda/cuda_common.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc index 574a3133de815..0f42363bca22d 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc @@ -24,9 +24,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr)) - -static ncclDataType_t 
GetNcclDataType(onnxruntime::MLDataType type) { +ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) { if (type == DataTypeImpl::GetType()) { return ncclUint8; } else if (type == DataTypeImpl::GetType()) { diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h index 7fc26e6be57b9..9ea61f2bd952d 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h @@ -7,17 +7,21 @@ #if defined(ORT_USE_NCCL) #include -#include #include -#include +#include #include #include +#include #endif namespace onnxruntime { namespace contrib { namespace cuda { +#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr)) + +ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type); + // ----------------------------------------------------------------------- // Defines a new version of nccl classes // that independent with training::DistributedRunContext, only rely on MPI diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc new file mode 100644 index 0000000000000..40a667ffd5d83 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/bert/transformer_cuda_common.h" +#include "sharded_moe.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ShardedMoE, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ShardedMoE); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +ShardedMoE::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("local_experts_start_index", &local_experts_start_index_).IsOK()); + rank_to_experts_start_index_.resize(nccl_->Size()); + // Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized. + rank_to_experts_start_index_[0] = std::numeric_limits::min(); +} + +template +Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + auto stream = context->GetComputeStream(); + + auto& device_prop = GetDeviceProp(); + const int sm = device_prop.major * 10 + device_prop.minor; + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + // Create a {Rank, ExpertsStartIndex} map on Host. 
+ AutoDestoryCudaEvent cuda_event; + cudaEvent_t& copy_event = cuda_event.Get(); + ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event)); + + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc2_experts_weights = context->Input(3); + const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc2_experts_bias_optional = context->Input(5); + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights, + fc1_experts_bias_optional, fc2_experts_bias_optional)); + ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, + "num_experts should be divisible by world_size"); + + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + + size_t ws_size = + moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(k_)); + + size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); + size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); + size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); + size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); + + // TODO: allocate one buffer and reuse it. + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr fc2_output_bc = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr expert_scales = + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocatorUniquePtr expert_for_source_row = + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + + // fc1_scales and fc2_scales are used in quantized MoE + const CudaT* fc1_scales_ptr = nullptr; + const CudaT* fc2_scales_ptr = nullptr; + + moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->template Data()), + std::move(fc1_scales_ptr), + fc1_experts_bias_optional == nullptr + ? 
nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), + std::move(fc2_scales_ptr), static_cast(moe_params.num_rows), + static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_), + static_cast(k_), reinterpret_cast(work_space.get()), + reinterpret_cast(fc2_output.get()), reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); + + Tensor* output = context->Output(0, input->Shape()); + + size_t stride_count = moe_params.hidden_size; + size_t stride_bytes = stride_count * sizeof(CudaT); + int64_t total_past_rows = 0; + int64_t total_covered_rows = 0; + if (copy_event != nullptr) { + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event)); + } + NCCL_RETURN_IF_ERROR(ncclGroupStart()); + for (int rank = 0; rank < nccl_->Size(); ++rank) { + int64_t experts_start_index = rank_to_experts_start_index_[rank]; + moe_runner.get_total_rows_info(experts_start_index, + moe_params.local_num_experts, + total_past_rows, + total_covered_rows); + const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes; + char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes; + NCCL_RETURN_IF_ERROR(ncclBroadcast(src, + dst, + total_covered_rows * stride_count, + GetNcclDataType(input->DataType()), + rank, + nccl_->Comm(), + Stream(context))); + } + NCCL_RETURN_IF_ERROR(ncclGroupEnd()); + + ort_fastertransformer::finalize_moe_routing_kernelLauncher( + reinterpret_cast(fc2_output_bc.get()), reinterpret_cast(output->template MutableData()), + fc2_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc2_experts_bias_optional->template Data()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), static_cast(moe_params.num_rows), + static_cast(moe_params.hidden_size), static_cast(k_), Stream(context)); + + return Status::OK(); +} + +template +Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, + OpKernelContext* context, + cudaEvent_t& cuda_event) const { + if (rank_to_experts_start_index_[0] != std::numeric_limits::min()) { + return Status::OK(); + } + + auto stream = context->GetComputeStream(); + + using IndexType = int64_t; + size_t IndexTypeSize = sizeof(IndexType); + + IAllocatorUniquePtr experts_start_index_d = + IAllocator::MakeUniquePtr(allocator, 1, false, stream); + IAllocatorUniquePtr rank_to_experts_start_index_d = + IAllocator::MakeUniquePtr(allocator, nccl_->Size(), false, stream); + + // Only happens in the first run. + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), + &local_experts_start_index_, + IndexTypeSize, + cudaMemcpyHostToDevice, + Stream(context))); + NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast(experts_start_index_d.get()), + reinterpret_cast(rank_to_experts_start_index_d.get()), + 1, + GetNcclDataType(DataTypeImpl::GetType()), + nccl_->Comm(), + Stream(context))); + // The const_cast<> violates the const modifier to make sure the synchronization happens only once per session. 
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(rank_to_experts_start_index_.data()), + rank_to_experts_start_index_d.get(), + nccl_->Size() * IndexTypeSize, + cudaMemcpyDeviceToHost, + Stream(context))); + + CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming)); + CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context))); + + return Status::OK(); +} +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h new file mode 100644 index 0000000000000..5ea4ae59c4020 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" +#include "contrib_ops/cuda/moe/moe_base.h" +#include "core/common/common.h" +#include "nccl_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +using namespace onnxruntime::cuda; + +template +class ShardedMoE final : public NcclKernel, public MoEBase { + public: + explicit ShardedMoE(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const; + + int64_t local_experts_start_index_; + std::vector rank_to_experts_start_index_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index b6b509023a1a9..1b4cc4502cff8 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -244,7 +244,7 @@ DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info // stored on a 1-D mesh with 2 devices and the second input on another 1-D // mesh with 1 device. std::vector attr_input_device_mesh_shapes; - ORT_ENFORCE(info.GetAttrs("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttrs("input_device_mesh_shapes", attr_input_device_mesh_shapes)); // input_device_mesh_elements[i] is the flattened device mesh for the i-th input. // Note that its actual shape is input_device_mesh_shapes[i]. @@ -255,12 +255,12 @@ DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info // Then the first input is stored on a 1-D mesh with 2 devices and the second // input on another 1-D mesh with 1 device. std::vector attr_input_device_mesh_elements; - ORT_ENFORCE(info.GetAttrs("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttrs("input_device_mesh_elements", attr_input_device_mesh_elements)); // input_shard_specs[i] is the sharding spec of the i-th input; e.g., // "RR" if the i-th input is not sharded. 
std::vector input_shard_specs; - ORT_ENFORCE(info.GetAttrs("input_shard_specs", input_shard_specs).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttrs("input_shard_specs", input_shard_specs)); ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size()); ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size()); @@ -274,13 +274,13 @@ DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info } std::vector attr_output_device_mesh_shapes; - ORT_ENFORCE(info.GetAttrs("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttrs("output_device_mesh_shapes", attr_output_device_mesh_shapes)); std::vector attr_output_device_mesh_elements; - ORT_ENFORCE(info.GetAttrs("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttrs("output_device_mesh_elements", attr_output_device_mesh_elements)); std::vector output_shard_specs; - ORT_ENFORCE(info.GetAttrs("output_shard_specs", output_shard_specs).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttrs("output_shard_specs", output_shard_specs)); ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size()); ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size()); diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 7172a28316f16..7875ac75b8188 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); @@ -164,6 +165,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); @@ -313,6 +317,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -362,6 +367,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc index 251850f621361..6cdccdb1becb1 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -14,17 +14,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX( \ - GemmFloat8, \ - kMSDomain, \ - 1, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("TA", BuildKernelDefConstraints()) \ - .TypeConstraint("TB", BuildKernelDefConstraints()) \ - .TypeConstraint("TR", BuildKernelDefConstraints()) \ - .TypeConstraint("TS", BuildKernelDefConstraints()), \ +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#else +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#endif + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TR", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ GemmFloat8); REGISTER_KERNEL() @@ -38,7 +44,7 @@ GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { alpha_ = info.GetAttrOrDefault("alpha", 1); beta_ = info.GetAttrOrDefault("beta", 0); -#if (CUDA_VERSION <= 12000) +#if (CUDA_VERSION < 12000) ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); #endif diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index df25342342cd5..064b6dd392437 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: return 2; -#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) +#if !defined(DISABLE_FLOAT8_TYPES) case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: return 1; @@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { } auto first_type = input_A->GetElementType(); +#if !defined(DISABLE_FLOAT8_TYPES) bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; if (!is_float8) +#endif return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#if !defined(DISABLE_FLOAT8_TYPES) return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#endif } Status GemmFloat8::ComputeRowMajor( @@ -197,10 +201,15 @@ Status GemmFloat8::ComputeGemm( switch (d_cuda_type) { case CUDA_R_16F: switch (a_cuda_type) { +#if !defined(DISABLE_FLOAT8_TYPES) +#if CUDA_VERSION < 11080 +#error CUDA_R_8F_E4M3 (float 8 types) is defined with CUDA>=11.8. Set flag DISABLE_FLOAT8_TYPES. 
+#endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; break; +#endif default: compute_type = CUBLAS_COMPUTE_32F_FAST_16F; break; @@ -242,15 +251,21 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); +#if CUDA_VERSION >= 11060 + // CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET exists from https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf if (sm_count_ != 0) { int math_sm_count = static_cast(sm_count_); CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count))); } +#endif if (has_scales) { // gemm float 8 +#if CUDA_VERSION >= 11080 + // CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER exist from https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf const int8_t ifast_accumulation_mode = 1; CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, @@ -265,9 +280,10 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, sizeof(p_scale_b))); +#endif // float 8 -#if CUDA_VERSION >= 11080 +#if !defined(DISABLE_FLOAT8_TYPES) if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { // For FP8 output, cuBLAS requires C_type to be same as bias_type @@ -280,15 +296,14 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR( cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); } - } else { - CUBLAS_RETURN_IF_ERROR( - cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); - } #else - // An output is still needed but it is not initialized. CUBLAS_RETURN_IF_ERROR( cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); #endif + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } if (row_major_compute) { cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; @@ -345,7 +360,7 @@ Status GemmFloat8::ComputeGemm( ". Check NVIDIA documentation to see what combination is valid: ", "https://docs.nvidia.com/cuda/cublas/" "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" - "cublasltmatmulalgogetheuristic."); + "cublasltmatmulalgogetheuristic. CUDA>=11.8 is required to use float 8 types."); void* workspace = nullptr; if (workspaceSize > 0) { @@ -381,7 +396,8 @@ Status GemmFloat8::ComputeGemm( ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, ", workspaceSize=", workspaceSize, - ", rowMajorCompute=", (row_major_compute ? 1 : 0), "."); + ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". CUDA>=11.8 is required to use float 8 types."); if (workspaceSize > 0) { CUDA_RETURN_IF_ERROR(cudaFree(workspace)); diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 398ce4ee9880f..f4f2b49032d23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
#include #include @@ -501,8 +503,27 @@ __global__ void compute_total_rows_before_expert_kernel(const int* sorted_expert total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); } +__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts, + int local_num_experts, int local_experts_start_index) { + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + const int local_experts_end_index = local_experts_start_index + local_num_experts - 1; + + int total_past_rows = 0; + if (local_experts_start_index > 0) { + total_past_rows = total_rows_before_expert[local_experts_start_index - 1]; + } + + if (expert < local_experts_start_index || expert > local_experts_end_index) { + return; + } + + total_rows_before_expert[expert] -= total_past_rows; +} + template CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) { + total_past_rows_ = 0; + total_covered_rows_ = 0; moe_gemm_runner_.initialize(sm_version); } @@ -549,7 +570,6 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); - // const int num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); source_rows_ = (int*)ws_ptr; permuted_rows_ = source_rows_ + num_moe_inputs; @@ -573,8 +593,9 @@ void CutlassMoeFCRunner::run_moe_fc( const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, - int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales, - int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { + int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, + const bool* finished, int active_rows, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, + int* expert_for_source_row, cudaStream_t stream) { static constexpr bool scales_required = std::is_same::value || std::is_same::value; @@ -608,12 +629,23 @@ void CutlassMoeFCRunner::run_moe_fc( compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, total_rows_before_expert_, stream); - moe_gemm_runner_.moe_gemm_bias_act(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_, - total_rows_before_expert_, expanded_active_expert_rows, inter_size, hidden_size, - num_experts, fc1_activation_type, stream); + if (local_num_experts < num_experts) { + dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, stream); + } - moe_gemm_runner_.moe_gemm(fc1_result_, fc2_expert_weights, fc2_scales, fc2_result, total_rows_before_expert_, - expanded_active_expert_rows, hidden_size, inter_size, num_experts, stream); + // expanded_active_expert_rows is not used + moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, fc1_scales, fc1_expert_biases, + fc1_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, inter_size, hidden_size, + local_num_experts, 
fc1_activation_type, stream); + + moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, + fc2_expert_weights, fc2_scales, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream); } template @@ -621,12 +653,12 @@ void CutlassMoeFCRunner::run_moe_fc( const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, - int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, - int* expert_for_source_row, cudaStream_t stream) { + int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, - fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, k, workspace_ptr, - fc2_result, nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, - expert_for_source_row, stream); + fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, local_num_experts, + local_experts_start_index, k, workspace_ptr, fc2_result, nullptr, num_rows, expert_scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream); } template @@ -642,6 +674,44 @@ void CutlassMoeFCRunner::compute_total_rows_before_expert total_rows_before_expert); } +template +void CutlassMoeFCRunner::dispatch_activations(int64_t* total_rows_before_expert, + int num_experts, int local_num_experts, + int local_experts_start_index, + cudaStream_t stream) { + total_rows_before_expert_host_.resize(num_experts); + cudaMemcpyAsync(total_rows_before_expert_host_.data(), total_rows_before_expert, num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, stream); + + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + cudaEvent_t& copy_event = cuda_event_.Get(); + cudaEventCreateWithFlags(©_event, cudaEventDisableTiming); + cudaEventRecord(copy_event, stream); + + dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, + local_num_experts, local_experts_start_index); + + get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_); +} + +template +void CutlassMoeFCRunner::get_total_rows_info(int64_t experts_start_index, + int64_t local_num_experts, + int64_t& total_past_rows, + int64_t& total_covered_rows) { + int64_t experts_end_index = experts_start_index + local_num_experts - 1; + total_past_rows = 0; + + cudaEventSynchronize(cuda_event_.Get()); + + if (experts_start_index > 0) { + total_past_rows = total_rows_before_expert_host_[experts_start_index - 1]; + } + total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows; +} + // ========================== Permutation things ======================================= // Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. 
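Both dispatch_activations and get_total_rows_info above treat total_rows_before_expert as an inclusive prefix sum over experts, so the rows owned by a contiguous expert range fall out of two lookups. A minimal host-side sketch, using a plain vector and a hypothetical helper name in place of the device buffer and member function:

#include <cstdint>
#include <vector>

// Sketch only: rows_before_expert[e] is the inclusive prefix sum of rows
// routed to experts 0..e. The slice owned by the local expert range
// [start, start + local_n) is then a single subtraction.
void LocalRowsInfo(const std::vector<int64_t>& rows_before_expert,
                   int64_t start, int64_t local_n,
                   int64_t& total_past_rows, int64_t& total_covered_rows) {
  total_past_rows = start > 0 ? rows_before_expert[start - 1] : 0;
  total_covered_rows = rows_before_expert[start + local_n - 1] - total_past_rows;
}

dispatch_activations_kernel then subtracts total_past_rows from the local experts' entries, so the GEMMs, which only see the local slice of activations, index rows from zero.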
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 5cefe4fa5dc47..5cc2a3f79f003 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once @@ -20,6 +22,7 @@ #include #include "core/common/common.h" +#include "contrib_ops/cuda/bert/transformer_cuda_common.h" using namespace onnxruntime; @@ -111,20 +114,26 @@ class CutlassMoeFCRunner { void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, - int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, - T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, - cudaStream_t stream); + int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k, + char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, + int* expert_for_source_row, cudaStream_t stream); void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, - int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, - const bool* finished, int active_rows, T* expert_scales, + int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k, + char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream); void compute_total_rows_before_expert(const int* sorted_indices, int total_indices, int num_experts, int64_t* total_rows_before_expert, cudaStream_t stream); + void dispatch_activations(int64_t* total_rows_before_expert, int num_experts, int local_num_experts, + int local_experts_start_index, cudaStream_t stream); + + void get_total_rows_info(int64_t experts_start_index, int64_t local_num_experts, int64_t& total_past_rows, + int64_t& total_covered_rows); + private: void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k); @@ -143,6 +152,14 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + + // Cuda events + contrib::cuda::AutoDestoryCudaEvent cuda_event_; + + int64_t total_past_rows_; + int64_t total_covered_rows_; + // TODO: use pinned memory + std::vector total_rows_before_expert_host_; }; template diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 6f2ffe7a0cc43..3f26a274109ad 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -30,6 +30,10 @@ REGISTER_KERNEL_TYPED(MLFloat16) using namespace ONNX_NAMESPACE; +template +MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { +} + template Status MoE::ComputeInternal(OpKernelContext* context) const { const 
Tensor* input = context->Input(0); @@ -39,95 +43,9 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc1_experts_bias_optional = context->Input(4); const Tensor* fc2_experts_bias_optional = context->Input(5); - const auto& input_dims = input->Shape().GetDims(); - const auto& router_probs_dims = router_probs->Shape().GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); - - const int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; - const int64_t hidden_size = input_dims[input_dims.size() - 1]; - const int64_t num_experts = fc1_experts_weights_dims[0]; - const int64_t inter_size = fc1_experts_weights_dims[2]; - - // TODO: refactor to helper function. - if (fc1_experts_weights_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", - fc1_experts_weights_dims.size()); - } - if (fc2_experts_weights_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", - fc2_experts_weights_dims.size()); - } - if (fc1_experts_weights_dims[1] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", - fc1_experts_weights_dims[1], " and ", hidden_size); - } - if (fc2_experts_weights_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[1] must be equal to inter_size, got ", fc2_experts_weights_dims[1], - " and ", inter_size); - } - if (fc1_experts_weights_dims[2] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[2] must be equal to inter_size, got ", fc1_experts_weights_dims[2], - " and ", inter_size); - } - if (fc2_experts_weights_dims[2] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - fc2_experts_weights_dims[2], " and ", hidden_size); - } - if (router_probs_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", - router_probs_dims.size()); - } - if (router_probs_dims[0] != num_rows) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", - router_probs_dims[0], " and ", num_rows); - } - if (router_probs_dims[1] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[1] must be equal to num_experts, got ", - router_probs_dims[1], " and ", num_experts); - } - if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set"); - } - if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set"); - } - if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { - const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); - const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); - if (fc1_experts_bias_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", - 
fc1_experts_bias_dims.size()); - } - if (fc2_experts_bias_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", - fc2_experts_bias_dims.size()); - } - if (fc1_experts_bias_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[0] must be equal to num_experts, got ", fc1_experts_bias_dims[0], - " and ", num_experts); - } - if (fc2_experts_bias_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], - " and ", num_experts); - } - if (fc1_experts_bias_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], - " and ", inter_size); - } - if (fc2_experts_bias_dims[1] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], - " and ", hidden_size); - } - } + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights, + fc1_experts_bias_optional, fc2_experts_bias_optional)); typedef typename ToCudaType::MappedType CudaT; auto stream = context->GetComputeStream(); @@ -138,12 +56,13 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); size_t ws_size = - moe_runner.getWorkspaceSize(static_cast(num_rows), static_cast(hidden_size), - static_cast(inter_size), static_cast(num_experts), static_cast(k_)); - size_t fc2_output_size = k_ * num_rows * hidden_size * sizeof(CudaT); - size_t expert_scales_size = k_ * num_rows * sizeof(CudaT); - size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int); - size_t expert_for_source_row_size = k_ * num_rows * sizeof(int); + moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(k_)); + size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); + size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); + size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); + size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -170,8 +89,10 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { ? 
nullptr : reinterpret_cast(fc1_experts_bias_optional->template Data()), activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), - std::move(fc2_scales_ptr), static_cast(num_rows), static_cast(hidden_size), - static_cast(inter_size), static_cast(num_experts), static_cast(k_), + std::move(fc2_scales_ptr), static_cast(moe_params.num_rows), + static_cast(moe_params.hidden_size), static_cast(moe_params.inter_size), + static_cast(moe_params.num_experts), static_cast(moe_params.local_num_experts), + 0 /*local_experts_start_index_ used in sharded MoE*/, static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), reinterpret_cast(expert_scales.get()), reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), @@ -186,7 +107,8 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { : reinterpret_cast(fc2_experts_bias_optional->template Data()), reinterpret_cast(expert_scales.get()), reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), - reinterpret_cast(expert_for_source_row.get()), static_cast(num_rows), static_cast(hidden_size), + reinterpret_cast(expert_for_source_row.get()), static_cast(moe_params.num_rows), + static_cast(moe_params.hidden_size), static_cast(k_), Stream(context)); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h index 8035568693814..c4d8c4dc64c57 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -4,6 +4,7 @@ #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" +#include "contrib_ops/cuda/moe/moe_base.h" #include "core/common/common.h" #include "core/providers/cuda/cuda_kernel.h" @@ -14,30 +15,10 @@ namespace cuda { using namespace onnxruntime::cuda; template -class MoE final : public CudaKernel { +class MoE final : public CudaKernel, public MoEBase { public: - explicit MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { - ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); - - std::string activation_type_str; - ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK()); - if (activation_type_str == "relu") { - activation_type_ = ort_fastertransformer::ActivationType::Relu; - } else if (activation_type_str == "gelu") { - activation_type_ = ort_fastertransformer::ActivationType::Gelu; - } else if (activation_type_str == "silu") { - activation_type_ = ort_fastertransformer::ActivationType::Silu; - } else if (activation_type_str == "identity") { - activation_type_ = ort_fastertransformer::ActivationType::Identity; - } else { - ORT_THROW("Unsupported MoE activation type: ", activation_type_str); - } - } + explicit MoE(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* ctx) const override; - - private: - int64_t k_; - ort_fastertransformer::ActivationType activation_type_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h new file mode 100644 index 0000000000000..f55a7cde2e208 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
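// Shape bookkeeping behind the CheckInputs helper that follows: num_experts is
// read from the router output (router_probs is [num_rows, num_experts]) while
// local_num_experts is dim 0 of the weight tensors actually resident on this
// rank, so expert slicing is detected purely from tensor shapes. Illustrative
// standalone sketch (these names are not part of the diff):

#include <cassert>
#include <cstdint>

enum class ParallelType { None, ExpertSlicing };

ParallelType DetectParallelType(int64_t num_experts, int64_t local_num_experts) {
  // Mirrors the logic below: equal -> unsharded, router wider -> sliced experts.
  assert(num_experts >= local_num_experts);
  return num_experts == local_num_experts ? ParallelType::None
                                          : ParallelType::ExpertSlicing;
}

int main() {
  // 8 experts sharded over 4 ranks: per-rank fc1 weights are [2, hidden, inter]
  // while router_probs stays [num_rows, 8].
  assert(DetectParallelType(/*num_experts=*/8, /*local_num_experts=*/2) == ParallelType::ExpertSlicing);
  assert(DetectParallelType(8, 8) == ParallelType::None);
  return 0;
}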
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +enum class MoEParallelType { + None = 0, + ExpertSlicing = 1, +}; + +struct MoEParameters { + int64_t num_rows; + int64_t num_experts; + int64_t local_num_experts; + int64_t hidden_size; + int64_t inter_size; + MoEParallelType parallel_type; +}; + +class MoEBase { + public: + Status CheckInputs(MoEParameters& parameters, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc2_experts_weights, + const Tensor* fc1_experts_bias_optional, + const Tensor* fc2_experts_bias_optional) const { + const auto& input_dims = input->Shape().GetDims(); + const auto& router_probs_dims = router_probs->Shape().GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); + + int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; + int64_t hidden_size = input_dims[input_dims.size() - 1]; + int64_t local_num_experts = fc1_experts_weights_dims[0]; + int64_t num_experts = router_probs_dims[1]; + int64_t inter_size = fc1_experts_weights_dims[2]; + + if (fc1_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", + fc1_experts_weights_dims.size()); + } + if (fc2_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", + fc2_experts_weights_dims.size()); + } + if (fc1_experts_weights_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", + fc1_experts_weights_dims[1], " and ", hidden_size); + } + if (fc2_experts_weights_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[1] must be equal to inter_size, got ", + fc2_experts_weights_dims[1], + " and ", inter_size); + } + if (fc1_experts_weights_dims[2] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[2] must be equal to inter_size, got ", + fc1_experts_weights_dims[2], + " and ", inter_size); + } + if (fc2_experts_weights_dims[2] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", + fc2_experts_weights_dims[2], " and ", hidden_size); + } + if (router_probs_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", + router_probs_dims.size()); + } + if (router_probs_dims[0] != num_rows) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", + router_probs_dims[0], " and ", num_rows); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set"); + } + if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set"); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { + const auto& 
fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); + const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); + if (fc1_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", + fc1_experts_bias_dims.size()); + } + if (fc2_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", + fc2_experts_bias_dims.size()); + } + if (fc1_experts_bias_dims[0] != local_num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ", + fc1_experts_bias_dims[0], + " and ", local_num_experts); + } + if (fc2_experts_bias_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[0] must be equal to num_experts, got ", + fc2_experts_bias_dims[0], + " and ", num_experts); + } + if (fc1_experts_bias_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[1] must be equal to inter_size, got ", + fc1_experts_bias_dims[1], + " and ", inter_size); + } + if (fc2_experts_bias_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", + fc2_experts_bias_dims[1], + " and ", hidden_size); + } + } + + parameters.num_rows = num_rows; + parameters.num_experts = num_experts; + parameters.local_num_experts = local_num_experts; + parameters.hidden_size = hidden_size; + parameters.inter_size = inter_size; + if (num_experts == local_num_experts) { + parameters.parallel_type = MoEParallelType::None; + } else if (num_experts > local_num_experts) { + parameters.parallel_type = MoEParallelType::ExpertSlicing; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_experts must be greater than or equal to local_num_experts, got ", + num_experts, " and ", local_num_experts); + } + + return Status::OK(); + } + + protected: + MoEBase(const OpKernelInfo& op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); + + std::string activation_type_str; + ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK()); + if (activation_type_str == "relu") { + activation_type_ = ort_fastertransformer::ActivationType::Relu; + } else if (activation_type_str == "gelu") { + activation_type_ = ort_fastertransformer::ActivationType::Gelu; + } else if (activation_type_str == "silu") { + activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "identity") { + activation_type_ = ort_fastertransformer::ActivationType::Identity; + } else { + ORT_THROW("Unsupported MoE activation type: ", activation_type_str); + } + } + + int64_t k_; + ort_fastertransformer::ActivationType activation_type_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 7921315ab52e1..6b66f1d84e221 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -64,8 +64,12 @@ __global__ void Dequantize4BitsKernel( int block_size, int blocks_per_K, int blocks_per_threadblock, + int total_blks, int shift) { int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) 
>> shift); + if (block_id >= total_blks) { + return; + } int n_idx = block_id / blocks_per_K; int kb_idx = block_id % blocks_per_K; int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); @@ -96,6 +100,7 @@ Status Dequantize4Bits( constexpr int element_per_thread = 8; int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; int blocks_per_K = k / block_size; + int total_blks = n * blocks_per_K; int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); int shift = static_cast(log2f(float(block_size))); @@ -107,6 +112,7 @@ Status Dequantize4Bits( block_size, blocks_per_K, blocks_per_threadblock, + total_blks, shift); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu index e58723f0b31e1..2f74dd41f0759 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -35,6 +35,8 @@ template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, c template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); +template Status SetBnbQuantMap(int quant_type, BFloat16* quant_map_buffer, cudaStream_t stream); + template __global__ void kDequantizeBlockwise( const T* quant_map, @@ -62,22 +64,15 @@ __global__ void kDequantizeBlockwise( valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; - local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + local_abs_max = absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]; __syncthreads(); LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; - vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; - #else - // half multiplication not supported - vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); - vals[j * 2 + 1] = - static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); - #endif + vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max); + vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max); } __syncthreads(); @@ -86,7 +81,7 @@ __global__ void kDequantizeBlockwise( } template -Status DequantizeBnb4( +void CallkDequantizeBlockwise( const T* quant_map, T* output, const uint8_t* quant_data, @@ -102,6 +97,18 @@ Status DequantizeBnb4( absmax, block_size / 2, numel); +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); return Status::OK(); } @@ -119,11 +126,36 @@ template Status DequantizeBnb4( const half* quant_map, half* output, const uint8_t* quant_data, - const half *absmax, + const half* absmax, int block_size, int numel, cudaStream_t stream); +template <> +Status DequantizeBnb4( + const BFloat16* quant_map, + BFloat16* output, + const uint8_t* quant_data, + const BFloat16* absmax, + int block_size, + int numel, + cudaStream_t stream) { + #if !defined(__CUDA_ARCH__) || 
__CUDA_ARCH__ >= 800 + CallkDequantizeBlockwise( + reinterpret_cast(quant_map), + reinterpret_cast(output), + quant_data, + reinterpret_cast(absmax), + block_size, + numel, + stream); + #else + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); + #endif + + return Status::OK(); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh index 4aef3ab699f9c..a0d38c9853cd6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -11,6 +11,38 @@ namespace cuda { template Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); +// templated scalar multiply function +template +__device__ inline T ScalarMul(T a, T b); + +template <> +__device__ inline float ScalarMul(float a, float b) { + return a * b; +} + +template <> +__device__ inline half ScalarMul(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return a * b; + #else + // half multiplication not supported + return static_cast(static_cast(a) * static_cast(b)); + #endif +} + +template <> +__device__ inline BFloat16 ScalarMul(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline nv_bfloat16 ScalarMul(nv_bfloat16 a, nv_bfloat16 b) { + return a * b; +} +#endif + template Status DequantizeBnb4( const T* quant_map, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index ecf332715d470..bbcb7de99781f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -145,6 +145,17 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T2", DataTypeImpl::GetTensorType()), MatMulBnb4); +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu index 1d9aa75ff3701..098e3618beddd 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -6,12 +6,44 @@ #include #include #include +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" #include "matmul_bnb4.cuh" namespace onnxruntime { namespace contrib { namespace cuda { +template +__device__ inline float ScalarMulFloatOut(T a, T b); + +template <> +__device__ inline float ScalarMulFloatOut(float a, float b) { + return a * b; +} + +template <> +__device__ inline float ScalarMulFloatOut(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return static_cast(a * b); + #else + // half multiplication not supported + return static_cast(a) * static_cast(b); + #endif +} + +template <> +__device__ inline float ScalarMulFloatOut(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 
multiply instruction on sm_80+ +template <> +__device__ inline float ScalarMulFloatOut(nv_bfloat16 a, nv_bfloat16 b) { + return static_cast(a * b); +} +#endif + #define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive( @@ -55,7 +87,7 @@ __global__ void kgemm_4bit_inference_naive( int inner_idx_halved = inner_idx / 2; int offset_B = ldb * row_B; int absidx = ((2 * offset_B) + inner_idx) / block_size; - local_absmax = __ldg(&(absmax[absidx])); + local_absmax = absmax[absidx]; if (row_B < N) { if ((inner_idx_halved + num_values_8bit) < (K / 2)) { @@ -78,18 +110,8 @@ __global__ void kgemm_4bit_inference_naive( for (int i = 0; i < 4; i++) { #pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; - local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; - #else - // half multiplication not supported - local_B[k * 2] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * - static_cast(local_absmax)); - local_B[k * 2 + 1] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * - static_cast(local_absmax)); - #endif + local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax); + local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax); } if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { @@ -116,12 +138,7 @@ __global__ void kgemm_4bit_inference_naive( // accumulate in float; small performance hit for Ampere, but lower error for outputs #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_C += static_cast(local_A[k] * local_B[k]); - #else - // half multiplication not supported - local_C += static_cast(local_A[k]) * static_cast(local_B[k]); - #endif + local_C += ScalarMulFloatOut(local_A[k], local_B[k]); } } } @@ -131,8 +148,19 @@ __global__ void kgemm_4bit_inference_naive( if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); } +bool CheckDims(int m, int k, int block_size) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + return true; +} + template -bool TryMatMulBnb4( +void Callkgemm_4bit_inference_naive( const T* quant_map, T* output, const T* a_data, @@ -143,22 +171,34 @@ bool TryMatMulBnb4( int k, int block_size, cudaStream_t stream) { - if (k % block_size != 0 || m > 1) { - return false; - } - // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] - if (block_size % 32 != 0 || block_size > 4096) { - return false; - } - int lda = k; int ldb = (k + 1) / 2; int ldc = n; int num_blocks = (n + 3) / 4; - constexpr int bits = std::is_same_v ? 16 : 32; + constexpr int bits = std::is_same_v ? 
32 : 16; kgemm_4bit_inference_naive<<>>( m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); return true; } @@ -187,6 +227,42 @@ template bool TryMatMulBnb4( int block_size, cudaStream_t stream); +template <> +bool TryMatMulBnb4( + const BFloat16* quant_map, + BFloat16* output, + const BFloat16* a_data, + const uint8_t* b_data_quant, + const BFloat16* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + Callkgemm_4bit_inference_naive( + reinterpret_cast(quant_map), + reinterpret_cast(output), + reinterpret_cast(a_data), + b_data_quant, + reinterpret_cast(absmax), + m, + n, + k, + block_size, + stream); + #else + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); + #endif + + return true; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/image_scaler.cc b/onnxruntime/contrib_ops/cuda/tensor/image_scaler.cc index a2169b29dc8f5..befad5661c43f 100644 --- a/onnxruntime/contrib_ops/cuda/tensor/image_scaler.cc +++ b/onnxruntime/contrib_ops/cuda/tensor/image_scaler.cc @@ -26,8 +26,8 @@ REGISTER_KERNEL_TYPED(MLFloat16) template ImageScaler::ImageScaler(const OpKernelInfo& info) : CudaKernel(info) { - ORT_ENFORCE(info.GetAttr("scale", &scale_).IsOK()); - ORT_ENFORCE(info.GetAttrs("bias", bias_).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttr("scale", &scale_)); + ORT_THROW_IF_ERROR(info.GetAttrs("bias", bias_)); b_data_ = GetScratchBuffer(bias_.size(), nullptr); // the transfer in kernel construction need to be sync on default stream. diff --git a/onnxruntime/contrib_ops/js/bert/attention.cc b/onnxruntime/contrib_ops/js/bert/attention.cc new file mode 100644 index 0000000000000..723ff00aa815e --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + Attention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Attention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.h b/onnxruntime/contrib_ops/js/bert/attention.h new file mode 100644 index 0000000000000..0fa823befa9b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
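// A note on the ImageScaler change above, which swaps ORT_ENFORCE(status.IsOK())
// for ORT_THROW_IF_ERROR(status): the difference is diagnosability. ORT_ENFORCE
// on a bool only reports the stringified condition, while ORT_THROW_IF_ERROR
// rethrows with the failing Status's own error message. Hedged sketch of the
// behavior:
//
//   Status s = info.GetAttr<float>("scale", &scale_);
//   ORT_ENFORCE(s.IsOK());    // throws "s.IsOK() was false" -- detail is lost
//   ORT_THROW_IF_ERROR(s);    // throws with the message carried by s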
+ +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class Attention : public JsKernel, AttentionBase { + public: + explicit Attention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + std::vector qkv_sizes(qkv_hidden_sizes_.size()); + if (qkv_hidden_sizes_.size() > 0) { + std::transform(qkv_hidden_sizes_.begin(), qkv_hidden_sizes_.end(), qkv_sizes.begin(), + [](int64_t sz) { return gsl::narrow_cast(sz); }); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Attention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + "qkvHiddenSizes" : $6 ? (Array.from(HEAP32.subarray(Number($7), Number($7) + $6))) : [], + "pastPresentShareBuffer" : !!$8, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_), + static_cast(qkv_hidden_sizes_.size()), + reinterpret_cast((qkv_sizes.size() > 0) ? qkv_sizes.data() : nullptr) >> 2, + static_cast(past_present_share_buffer_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc new file mode 100644 index 0000000000000..c43f8b7f18465 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "multi_head_attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + MultiHeadAttention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.h b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h new file mode 100644 index 0000000000000..6c63a2ffed4b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
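// A note on the JSEP attribute block above: qkvHiddenSizes is marshalled by
// passing the C++ array's byte address shifted right by 2. Emscripten's HEAP32
// is an Int32Array view over wasm linear memory, so dividing a byte pointer by
// sizeof(int32_t) yields the element index that HEAP32.subarray expects. Sketch
// of the equivalence (illustrative):
//
//   const int32_t* p = qkv_sizes.data();
//   uintptr_t byte_addr = reinterpret_cast<uintptr_t>(p);
//   uintptr_t heap32_index = byte_addr >> 2;  // == byte_addr / sizeof(int32_t)
//   // JS side: HEAP32.subarray(heap32_index, heap32_index + qkv_sizes.size())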
+ +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class MultiHeadAttention : public JsKernel, AttentionBase { + public: + explicit MultiHeadAttention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + JSEP_INIT_KERNEL_ATTRIBUTE(MultiHeadAttention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 24d327576ecd9..498a9f5679eb5 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -7,7 +7,9 @@ namespace onnxruntime { namespace contrib { namespace js { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); @@ -21,7 +23,9 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo +template auto GetCKGemmAddFastGeluTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< ALayout, BLayout, ck::Tuple, Row, CKDataType, CKDataType, ck::Tuple, CKDataType, @@ -76,9 +79,11 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { return ret; } -template +template auto GetCKGemmFastGeluTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< ALayout, BLayout, ck::Tuple<>, Row, CKDataType, CKDataType, ck::Tuple<>, CKDataType, diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu index 294e7be91e883..8d7e64b1015be 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -49,16 +49,16 @@ inline GEMMFASTGELU(T, ScalarT) { if (tuning_ctx->IsTunableOpEnabled()) { if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + 
static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; return gemm_fast_gelu(¶ms); } } diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh index 229f868a215fd..e157aa57f8c43 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -51,24 +51,24 @@ Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { params->c); } -template +template class GemmFastGeluTunableOp : public TunableOp> { public: GemmFastGeluTunableOp() { this->RegisterOp(GemmFastGeluUnfused); #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif #ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { + for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index 0146e81c6cf8c..fb7091592c16e 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -34,17 +34,17 @@ constexpr int NumReduceDim = 3; template auto GetCKGroupNormNHWCTypeStringAndOps() { - using InDataType = typename CKDataTypeAdaptor::type; - using OutDataType = typename CKDataTypeAdaptor::type; - using AccDataType = typename CKDataTypeAdaptor::type; + using XDataType = typename CKDataTypeAdaptor::type; + using YDataType = typename CKDataTypeAdaptor::type; + using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; using GammaDataType = float; using BetaDataType = float; using Activation = std::conditional_t; std::vector>>> ret; - for (auto&& impl : internal::GetDeviceGroupNormInstances()) { + for (auto&& impl : internal::GetDeviceGroupNormInstances()) { std::string swish_suffix = WithSwish ? 
"_Swish" : "_Pass"; auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; auto invoker = impl->MakeInvokerPointer(); @@ -69,6 +69,8 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { gamma_beta_strides, // gammaStrides gamma_beta_strides, // betaStrides in_out_strides, // yStrides + {0, 0}, // saveMeanStrides + {0, 0}, // saveInvStdStrides reduce_dims, // reduceDims params->epsilon, params->src, diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 88443478cf521..19b081881dcec 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -6,8 +6,8 @@ #ifdef USE_COMPOSABLE_KERNEL #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" #include "ck/utility/data_type.hpp" namespace onnxruntime { @@ -21,102 +21,104 @@ using F32 = float; using Swish = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceNormalization; // the interface -using ck::tensor_operation::device::DeviceNormalizationImpl; // the implementation +using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface +using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation + +// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp template using device_normalization_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + 
DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; template -using device_normalization_f16_instances = std::tuple< +using device_normalization_f16_instances = // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; // Use this function to get implementation -template -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { return {}; } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Swish, 5, 3>(); template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Pass, 5, 3>(); + F16, F32, F32, F16, F32, Pass, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Swish, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Pass, 5, 3>(); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu index d1dd78e3452da..6718f29268031 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" 
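// Context for the instance-table churn above: composable_kernel renamed its
// normalization interface DeviceNormalization -> DeviceNormalizationFwd, and the
// forward op now carries a SaveMeanInvStdDataType template argument plus
// saved-mean / saved-inv-std outputs. The {0, 0} saveMeanStrides and
// saveInvStdStrides arguments added in group_norm_ck.cuh leave those extra
// outputs unused by GroupNorm.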
-#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 97baed34a341d..9b0ccab17b4c1 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{}); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 526d220d4be24..b7b9441ac997d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -77,7 +77,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { params->epsilon}; // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->stream, i, params->n, params->groups, 1, &args, sizeof(args)); + return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); }; ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); } diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu new file mode 100644 index 0000000000000..1e175b37b02d8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/common/common.h" +#include "core/framework/float16.h" +#include "core/providers/rocm/rocm_kernel.h" +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; +using namespace onnxruntime::rocm::tunable::blas; + +class GemmFloat8 final : public RocmKernel { + public: + GemmFloat8(const OpKernelInfo& info) : RocmKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + } + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: +#if !defined(DISABLE_FLOAT8_TYPES) + template + Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const; + template + Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; + + template + [[nodiscard]] inline auto* GetOp() const { + using OpT = GemmFloat8TunableOp; + if (tunable_op_) { + return static_cast(tunable_op_.get()); + } + + auto create = std::make_unique(); // avoid new + tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { + auto release = std::unique_ptr(); // avoid delete + release.reset(static_cast(ptr)); + }); + + return static_cast(tunable_op_.get()); + } +#endif + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t dtype_; + + // fully type erased + mutable std::shared_ptr tunable_op_; +}; + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { +#if defined(DISABLE_FLOAT8_TYPES) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DISABLE_FLOAT8_TYPES"); +#else + const Tensor* A = ctx->Input(0); + const Tensor* B = ctx->Input(1); + const Tensor* C = ctx->Input(2); // bias + const Tensor* scale_a = ctx->Input(3); + const Tensor* scale_b = ctx->Input(4); + const Tensor* scale_y = ctx->Input(5); + + auto a_shape = A->Shape(); + auto b_shape = B->Shape(); + ORT_ENFORCE(a_shape.NumDimensions() == 2); + ORT_ENFORCE(b_shape.NumDimensions() == 2); + + auto m = !transA_ ? a_shape[0] : a_shape[1]; + auto k = !transA_ ? a_shape[1] : a_shape[0]; + ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatiable + auto n = !transB_ ? 
b_shape[1] : b_shape[0]; + + TensorShapeVector output_shape = {m, n}; + Tensor* Y = ctx->Output(0, output_shape); + + ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); + ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); + ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); + ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); + + if (A->IsDataType()) { + return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); + } else if (A->IsDataType()) { + return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); + } else if (B->IsDataType()) { + return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); + } else if (B->IsDataType()) { + return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8"); +#endif +} + +#if !defined(DISABLE_FLOAT8_TYPES) +template +Status GemmFloat8::ComputeFp8Fp16Fp16( + OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const { + ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType()); + + onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; + params.tuning_ctx = GetTuningContext(); + params.stream = ctx->GetComputeStream(); + params.handle = GetRocblasHandle(ctx); + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + + params.m = m; + params.n = n; + params.k = k; + + params.a = static_cast(A->DataRaw()); + params.lda = transA_ ? m : k; + params.scale_a = alpha_; + params.scale_a_dev = static_cast(scale_a->DataRaw()); + + params.b = static_cast(B->DataRaw()); + params.ldb = transB_ ? k : n; + params.scale_b = 1.0f; // NOTE: not used + params.scale_b_dev = nullptr; // NOTE: not used + + params.c = static_cast(C->MutableDataRaw()); + params.ldc = n; + params.scale_c = 1.0f; // NOTE: not implemented + params.scale_c_dev = nullptr; // NOTE: not implemented + + if (!transA_ && !transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transB is not implemented"); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); +} + +template +Status GemmFloat8::ComputeFp16Fp8Fp16( + OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const { + ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType()); + + onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; + params.tuning_ctx = GetTuningContext(); + params.stream = ctx->GetComputeStream(); + params.handle = GetRocblasHandle(ctx); + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + + params.m = m; + params.n = n; + params.k = k; + + params.a = static_cast(A->DataRaw()); + params.lda = transA_ ? m : k; + params.scale_a = 1.0f; // NOTE: not used + params.scale_a_dev = nullptr; // NOTE: not used + + params.b = static_cast(B->DataRaw()); + params.ldb = transB_ ? 
k : n; + params.scale_b = alpha_; + params.scale_b_dev = static_cast(scale_b->DataRaw()); + + params.c = static_cast(C->MutableDataRaw()); + params.ldc = n; + params.scale_c = 1.0f; // NOTE: not implemented + params.scale_c_dev = nullptr; // NOTE: not implemented + + if (!transA_ && !transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); +} +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#else +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#endif + +ONNX_OPERATOR_KERNEL_EX( + GemmFloat8, + kMSDomain, + 1, + kRocmExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) + .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) + .TypeConstraint("TR", BuildKernelDefConstraints()) + .TypeConstraint("TS", BuildKernelDefConstraints()), + GemmFloat8); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh new file mode 100644 index 0000000000000..571936fc5f038 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#if defined(USE_COMPOSABLE_KERNEL) + +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/utility/functional3.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif + +#if !defined(DISABLE_FLOAT8_TYPES) +#include "core/framework/float8.h" +#endif +#include "core/providers/rocm/tunable/gemm_common.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +constexpr bool always_false = false; + +template +struct Scale { + constexpr const static bool is_pack2_invocable = true; + constexpr const static bool is_pack4_invocable = true; + + explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} + + template + __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { + static_assert(always_false, "not implemented"); + (void)x; + } + + template <> + __forceinline__ __host__ __device__ ck::half_t fast_type_convert(ck::f8_t x) const { + // https://github.com/ROCmSoftwarePlatform/triton/blob/0cc3f8b84a16892396f6e08a04991034d67e32b1/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L220-L233 + constexpr const uint16_t mask = 0x7fff; + constexpr const uint16_t sign_mask = 0x8000; + constexpr const uint16_t exp_compensate = []() { + if constexpr (std::is_same_v) { + return 0x2000; + } else if constexpr (std::is_same_v) { + return 0x1c00; + } + }(); + + uint8_t x_u8 = reinterpret_cast(x); + uint16_t x_u16 = static_cast(x_u8) << 8; + uint16_t exp = (x_u16 & mask) >> 1; + uint16_t y = (x_u16 & 
+
+  __forceinline__ __host__ __device__ void operator()(ck::half_t& y, const ck::f8_t& x) const {
+    float scale = scale_value_ * (*dev_scale_ptr_);
+    y = ck::type_convert<ck::half_t>(scale * fast_type_convert<ck::half_t>(x));
+  }
+
+  __forceinline__ __host__ __device__ void operator()(ck::half2_t& ys, const ck::f8x2_t& xs) const {
+    float scale = scale_value_ * (*dev_scale_ptr_);
+    constexpr const uint32_t mask = 0x7fff7fff;
+    constexpr const uint32_t sign_mask = 0x80008000;
+    constexpr const uint32_t exp_compensate = []() {
+      if constexpr (std::is_same_v<Fp8T, Float8E4M3FN>) {
+        return 0x20002000;
+      } else if constexpr (std::is_same_v<Fp8T, Float8E4M3FNUZ>) {
+        return 0x1c001c00;
+      }
+    }();
+
+    const uchar2& x2_u8 = reinterpret_cast<const uchar2&>(xs);
+    uchar4 x{0, x2_u8.x, 0, x2_u8.y};
+    uint32_t x_u32 = reinterpret_cast<uint32_t&>(x);
+
+    uint32_t exp = (x_u32 & mask) >> 1;
+    uint32_t v = (x_u32 & sign_mask) | (exp + exp_compensate);
+    ys = scale * reinterpret_cast<ck::half2_t&>(v);
+  }
+
+  __forceinline__ __host__ __device__ void operator()(ck::half4_t& ys, const ck::f8x4_t& xs) const {
+    float scale = scale_value_ * (*dev_scale_ptr_);
+    constexpr const uint32_t mask = 0x7fff7fff;
+    constexpr const uint32_t sign_mask = 0x80008000;
+    constexpr const uint32_t exp_compensate = []() {
+      if constexpr (std::is_same_v<Fp8T, Float8E4M3FN>) {
+        return 0x20002000;
+      } else if constexpr (std::is_same_v<Fp8T, Float8E4M3FNUZ>) {
+        return 0x1c001c00;
+      }
+    }();
+
+    uint32_t xs_u32 = reinterpret_cast<uint32_t&>(xs);
+    uint32_t x_u32_0 = __byte_perm(xs_u32, 0, 0x1504);
+    uint32_t x_u32_1 = __byte_perm(xs_u32, 0, 0x3726);
+    uint32_t exp_0 = (x_u32_0 & mask) >> 1;
+    uint32_t exp_1 = (x_u32_1 & mask) >> 1;
+    uint32_t v_0 = (x_u32_0 & sign_mask) | (exp_0 + exp_compensate);
+    uint32_t v_1 = (x_u32_1 & sign_mask) | (exp_1 + exp_compensate);
+    uint64_t v = v_0 | uint64_t(v_1) << 32;
+    ys = scale * reinterpret_cast<ck::half4_t&>(v);
+  }
+
+  float scale_value_;
+  const float* const dev_scale_ptr_;
+};
+#endif
+
+namespace blas {
+
+template <typename TA, typename TB, typename TC>
+struct GemmFloat8Params : tunable::OpParams {
+  std::string Signature() const override {
+    return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k);
+  }
+
+  rocblas_handle handle;
+  BlasOp opa;
+  BlasOp opb;
+  int64_t m;
+  int64_t n;
+  int64_t k;
+  float scale_a{};
+  const float* scale_a_dev{};
+  const TA* a;
+  int64_t lda;
+  float scale_b{};
+  const float* scale_b_dev{};
+  const TB* b;
+  int64_t ldb;
+  TC* c;
+  float scale_c{};
+  const float* scale_c_dev{};
+  int64_t ldc;
+};
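+
+// A minimal sketch of how the kernel above fills this struct for the supported
+// row-major, non-transposed case Y[m, n] = scale_a * A[m, k] (fp8) * B[k, n] (fp16):
+// opa = opb = BlasOp::NonTrans, lda = k, ldb = n, ldc = n; scale_a/scale_a_dev carry
+// alpha_ and the device-side scale tensor, while the unused b/c scales stay at 1.0f
+// with null device pointers.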
+
+#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES)
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using Nop = ck::tensor_operation::element_wise::PassThrough;
+
+void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, Nop, Nop>>>& instances);
+
+void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, Nop, Nop>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, Nop, Scale<Float8E4M3FN>, Nop>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, Nop, Scale<Float8E4M3FNUZ>, Nop>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Col, Row, F16, F8, F16, Nop, Scale<Float8E4M3FN>, Nop>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Col, Row, F16, F8, F16, Nop, Scale<Float8E4M3FNUZ>, Nop>>>& instances);
+
+template <typename T>
+auto CreateOp(float scale, const float* dev_scale) {
+  if constexpr (std::is_same_v<T, Float8E4M3FN>) {
+    return Scale<Float8E4M3FN>(scale, dev_scale);
+  } else if constexpr (std::is_same_v<T, Float8E4M3FNUZ>) {
+    return Scale<Float8E4M3FNUZ>(scale, dev_scale);
+  } else {
+    return Nop{};
+  }
+}
+
+template <typename TA, typename TB, typename TC, BlasOp BlasOpA, BlasOp BlasOpB>
+auto GetCKF8SplitKGemmTypeStringAndOps() {
+  using CKTA = typename CKDataTypeAdaptor<TA>::type;
+  using CKTB = typename CKDataTypeAdaptor<TB>::type;
+  using CKTC = typename CKDataTypeAdaptor<TC>::type;
+
+  using CKLayoutA = typename CKBlasOpAdaptor<BlasOpA>::type;
+  using CKLayoutB = typename CKBlasOpAdaptor<BlasOpB>::type;
+
+  using OpA = std::conditional_t<std::is_same_v<CKTA, ck::f8_t>, Scale<TA>, Nop>;
+  using OpB = std::conditional_t<std::is_same_v<CKTB, ck::f8_t>, Scale<TB>, Nop>;
+  using OpC = std::conditional_t<std::is_same_v<CKTC, ck::f8_t>, Scale<TC>, Nop>;
+
+  using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK<
+      CKLayoutA, CKLayoutB, Row,
+      CKTA, CKTB, CKTC,
+      OpA, OpB, OpC>;
+
+  std::vector<std::pair<std::string, Op<GemmFloat8Params<TA, TB, TC>>>> ret;
+
+  for (auto num_split : {1, 4, 16, 64}) {
+    std::vector<std::unique_ptr<DeviceGemm>> instances{};
+    if constexpr (std::is_same_v<CKTA, ck::f8_t> && std::is_same_v<CKTB, ck::half_t> && std::is_same_v<CKTC, ck::half_t> &&
+                  std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Row>) {
+      add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances);
+    } else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&
+                         std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Row>) {
+      add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances);
+    } else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&
+                         std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Col>) {
+      add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances);
+    } else {
+      static_assert(always_false<CKTA>, "no instances for the type combination");
+      LOGS_DEFAULT(FATAL) << "no instances for the type combination";
+    }
+    for (auto&& impl : instances) {
+      auto type_string = std::to_string(ret.size()) + "_" + impl->GetTypeString() + "_SplitK" + std::to_string(num_split);
+      auto invoker = impl->MakeInvokerPointer();
+      auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmFloat8Params<TA, TB, TC>* params) -> Status {
+        OpA op_a = CreateOp<TA>(params->scale_a, params->scale_a_dev);
+        OpB op_b = CreateOp<TB>(params->scale_b, params->scale_b_dev);
+        OpC op_c = CreateOp<TC>(params->scale_c, params->scale_c_dev);
+
+        auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
+                                             params->m, params->n, params->k,
+                                             params->lda, params->ldb, params->ldc,
+                                             op_a, op_b, op_c, num_split);
+        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
+                                                  impl->GetTypeString(), " does not support ", params->Signature());
+        invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
+        return Status::OK();
+      };
+      ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op)));
+    }
+  }
+  return ret;
+}
+
+#endif  // USE_COMPOSABLE_KERNEL
+
+template <typename TA, typename TB, typename TC, BlasOp BlasOpA, BlasOp BlasOpB>
+class GemmFloat8TunableOp : public TunableOp<GemmFloat8Params<TA, TB, TC>> {
+ public:
+  GemmFloat8TunableOp() {
+#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES)
+    for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, BlasOpA, BlasOpB>()) {
+      ORT_UNUSED_PARAMETER(_);
+      this->RegisterOp(std::move(op));
+    }
+#else
+    ORT_ENFORCE(false, "CK is required to support GemmFloat8 computing");
+#endif  // USE_COMPOSABLE_KERNEL
+  }
+};
+
+}  // namespace blas
+}  // namespace tunable
+}  // namespace rocm
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu
new file mode 100644
index 0000000000000..4c691dd18f2e9
--- /dev/null
+++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu
@@ -0,0 +1,124 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
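+
+// NOTE: This translation unit only aggregates instance lists. For each supported
+// (layout, fp8 type) combination it merges the parameterizations taken verbatim
+// from composable kernel (the "_ck" suffix) with the ORT-derived variants (the
+// "_ort" suffix) defined in the per-instance .cu files below, so the tunable op
+// can benchmark both sets and pick the fastest.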
+
+#include <memory>
+
+#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES)
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "contrib_ops/rocm/math/gemm_float8_ck.cuh"
+
+namespace onnxruntime {
+namespace rocm {
+namespace tunable {
+namespace blas {
+
+using F8 = ck::f8_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+namespace internal {
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances);
+}  // namespace internal
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances) {
+  internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
+  internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
+}
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances) {
+  internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
+  internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
+}
+
+namespace internal {
+void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>& instances);
+
+// TODO: The first try of derivation did not go well due to various constraints.
+// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(
+//     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+//         Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, PassThrough, PassThrough>>>& instances);
+
+// TODO: The first try of derivation did not go well due to various constraints.
+// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(
+//     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+//         Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>& instances);
+}  // namespace internal
+
+void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, PassThrough, PassThrough>>>& instances) {
+  internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances);
+  // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances);  // TODO:
+}
+
+void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>& instances) {
+  internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances);
+  // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances);  // TODO:
+}
+
+namespace internal {
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances);
+}  // namespace internal
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances) {
+  internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances);
+}
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances) {
+  internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances);
+}
+
+}  // namespace blas
+}  // namespace tunable
+}  // namespace rocm
+}  // namespace onnxruntime
+
+#endif
diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu
new file mode 100644
index 0000000000000..49463e58886f8
--- /dev/null
+++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu
@@ -0,0 +1,97 @@
+// SPDX-License-Identifier: MIT
+// Modifications Copyright (c) Microsoft.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
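+
+// NOTE: Two instance lists follow. The "_generic" tuple uses GemmMNKPadding and
+// conservative per-vector access widths, so it should remain applicable to odd or
+// unaligned problem shapes; the "_ort" tuple (see the note preceding it below)
+// doubles BBlockTransferSrcScalarPerVector relative to the upstream CK list and is
+// only expected to win on well-aligned shapes.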
+ +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| 
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 8, 4, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 8, 4, 32, 32, 3, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 8, 4, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 
1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 12, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 16, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 8, 4, 32, 32, 3, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 8, 4, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu new file mode 100644 index 0000000000000..236e5555051fc --- /dev/null +++ 
b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, 
Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + 
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu new file mode 100644 index 0000000000000..1a0d45df82a71 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| 
Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 
1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + 
ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu new file mode 100644 index 0000000000000..a0628802ec09e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + 
DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 
1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + 
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
+        Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>& instances) {
+  ck::tensor_operation::device::instance::add_device_operation_instances(
+      instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck<Float8E4M3FNUZ>{});
+  ck::tensor_operation::device::instance::add_device_operation_instances(
+      instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic<Float8E4M3FNUZ>{});
+}
+
+}  // namespace internal
+}  // namespace blas
+}  // namespace tunable
+}  // namespace rocm
+}  // namespace onnxruntime
+
+#endif
diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
index 0f8fe68de717a..55cd6a1d112f5 100644
--- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -138,6 +138,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8);
 
 #ifdef ENABLE_ATEN
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen);
@@ -296,6 +297,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8)>,
 
 #ifdef ENABLE_ATEN
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc
index c3a9e5950acce..19545d1554405 100644
--- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc
+++ b/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc
@@ -29,9 +29,9 @@ Status GENERIC_OP_IR_CREATOR_CLASS(Conv)::Evaluate(
   info.GetAttrOrDefault("group", &group, 1);
   info.GetAttrOrDefault("auto_pad", &auto_pad, "NOTSET");
 
-  ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK());
+  ORT_THROW_IF_ERROR(info.GetAttrs("kernel_shape", kernel_shape));
   ORT_ENFORCE(kernel_shape.size() <= 2, "Only support 1D/2D convolution currently!");
-  ORT_ENFORCE(info.GetAttrs("strides", strides).IsOK());
+  ORT_THROW_IF_ERROR(info.GetAttrs("strides", strides));
 
   dilations = info.GetAttrs("dilations", dilations).IsOK() ? dilations : std::vector<int64_t>(kernel_shape.size(), 1);
  ORT_ENFORCE(dilations == std::vector<int64_t>(kernel_shape.size(), 1), "Only support dilation is 1 currently");
diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc
index ecff2c7b73847..e9e20e8a43998 100644
--- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc
+++ b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc
@@ -23,9 +23,9 @@ Status GENERIC_OP_IR_CREATOR_CLASS(Pad)::Evaluate(
   std::vector<int64_t> pads;
   float value;
 
-  ORT_ENFORCE(attrs.GetAttr("mode", &mode).IsOK());
-  ORT_ENFORCE(attrs.GetAttrs("pads", pads).IsOK());
-  ORT_ENFORCE(attrs.GetAttr("value", &value).IsOK());
+  ORT_THROW_IF_ERROR(attrs.GetAttr("mode", &mode));
+  ORT_THROW_IF_ERROR(attrs.GetAttrs("pads", pads));
+  ORT_THROW_IF_ERROR(attrs.GetAttr("value", &value));
 
   if (mode != "constant" && mode != "edge" && mode != "reflect")
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Pad: Unsupported padding mode!");
diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc
index 655d5014f3d60..fcf9c2b03dea5 100644
--- a/onnxruntime/core/common/cpuid_info.cc
+++ b/onnxruntime/core/common/cpuid_info.cc
@@ -183,7 +183,8 @@ void CPUIDInfo::ArmLinuxInit() {
 #elif defined(_WIN32)
 
 void CPUIDInfo::ArmWindowsInit() {
-
+// ARM32 certainly doesn't have fp16, so we skip this logic to avoid using the RegGetValueA Windows API
+#ifndef _M_ARM
 #pragma region Application Family or OneCore Family
 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM)
   // Read MIDR from windows registry
@@ -270,6 +271,9 @@ void CPUIDInfo::ArmWindowsInit() {
 #endif /* Application Family or OneCore Family */
   has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
+#else
+  has_arm_neon_dot_ = false;
+#endif
   has_fp16_ |= has_arm_neon_dot_;
   /* TODO: implement them when hw+sw is available for testing these features */
   has_arm_neon_i8mm_ = false;
diff --git a/onnxruntime/core/common/path_string.h b/onnxruntime/core/common/path_string.h
index 76434f5453549..6cfb327cce08a 100644
--- a/onnxruntime/core/common/path_string.h
+++ b/onnxruntime/core/common/path_string.h
@@ -13,6 +13,15 @@
 #include
 #endif
 
+// for converting / printing ORT_TSTR path strings to std::string
+#ifdef _WIN32
+#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) std::wstring_convert<std::codecvt_utf8<ORTCHAR_T>>().to_bytes(X)
+#define ORT_TSTR_CONVERT_FROM_STRING(X) std::wstring_convert<std::codecvt_utf8<ORTCHAR_T>>().from_bytes(X);
+#else
+#define ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(X) X
+#define ORT_TSTR_CONVERT_FROM_STRING(X) X
+#endif
+
 #include "core/common/common.h"
 #include "core/session/onnxruntime_c_api.h"
diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h
index 6e0eb460d2a63..eca1221e84cb8 100644
--- a/onnxruntime/core/common/string_utils.h
+++ b/onnxruntime/core/common/string_utils.h
@@ -3,6 +3,7 @@
 
 #pragma once
 
+#include <algorithm>
 #include
 #include
 
@@ -37,5 +38,32 @@ inline InlinedVector SplitString(std::string_view string_to_sp
   return result;
 }
 
+/**
+ * Trim a string from start inplace.
+ * @param s The string to trim.
+ */
+inline void TrimStringFromLeft(std::string& s) {
+  s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); }));
+}
+
+/**
+ * Trim a string from end inplace.
+ * @param s The string to trim.
+ */
+inline void TrimStringFromRight(std::string& s) {
+  s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end());
+}
+
+/**
+ * Trim a string from both ends.
+ * @param s The string to trim.
+ * @return The trimmed string.
+ */
+inline std::string TrimString(std::string s) {
+  TrimStringFromRight(s);
+  TrimStringFromLeft(s);
+  return s;
+}
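+
+// Usage sketch: TrimString(" \t onnx \n") returns "onnx", while the
+// FromLeft/FromRight variants trim their argument in place and return nothing.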
+
 }  // namespace utils
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 9556e056dedc0..ea7a6432a7507 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -1035,8 +1035,11 @@ class PlannerImpl {
     std::function<void(NodeIndex)> dfs = [&](NodeIndex curr) {
       if (dependents.find(curr) == dependents.end()) {
         dependents.insert(curr);
-        for (NodeIndex dep : dependence_graph_[curr]) {
-          dfs(dep);
+        auto dep_graph_iter = dependence_graph_.find(curr);
+        if (dep_graph_iter != dependence_graph_.end()) {
+          for (NodeIndex dep : dep_graph_iter->second) {
+            dfs(dep);
+          }
         }
       }
     };
diff --git a/onnxruntime/core/framework/config_options.cc b/onnxruntime/core/framework/config_options.cc
index 3b322e1fcd689..1a4acb6dabf71 100644
--- a/onnxruntime/core/framework/config_options.cc
+++ b/onnxruntime/core/framework/config_options.cc
@@ -52,4 +52,11 @@ Status ConfigOptions::AddConfigEntry(const char* config_key, const char* config_
   return Status::OK();
 }
 
+std::ostream& operator<<(std::ostream& os, const ConfigOptions& config_options) {
+  for (const auto& [key, value] : config_options.configurations) {
+    os << " " << key << ": " << value;
+  }
+  return os;
+}
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/config_options.h b/onnxruntime/core/framework/config_options.h
index 4297819bed111..7b7c226819e79 100644
--- a/onnxruntime/core/framework/config_options.h
+++ b/onnxruntime/core/framework/config_options.h
@@ -32,6 +32,8 @@ struct ConfigOptions {
 
   // Add a config pair (config_key, config_value) to this instance of ConfigOptions
   Status AddConfigEntry(const char* config_key, const char* config_value) noexcept;
+
+  friend std::ostream& operator<<(std::ostream& os, const ConfigOptions& config_options);
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h
index 7bf11f8293a36..d97953fd9d5ea 100644
--- a/onnxruntime/core/framework/execution_providers.h
+++ b/onnxruntime/core/framework/execution_providers.h
@@ -12,6 +12,9 @@
 #include "core/framework/execution_provider.h"
 #include "core/graph/graph_viewer.h"
 #include "core/common/logging/logging.h"
+#ifdef _WIN32
+#include "core/platform/tracing.h"
+#endif
 
 namespace onnxruntime {
 
@@ -36,7 +39,19 @@ class ExecutionProviders {
     ORT_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx}));
 
     // update execution provider options
-    exec_provider_options_[provider_id] = p_exec_provider->GetProviderOptions();
+    auto providerOptions = p_exec_provider->GetProviderOptions();
+    exec_provider_options_[provider_id] = providerOptions;
+
+#ifdef _WIN32
+    for (const auto& config_pair : providerOptions) {
+      TraceLoggingWrite(
+          telemetry_provider_handle,
+          "ProviderOptions",
+          TraceLoggingString(provider_id.c_str(), "ProviderId"),
+          TraceLoggingString(config_pair.first.c_str(), "Key"),
+          TraceLoggingString(config_pair.second.c_str(), "Value"));
+    }
+#endif
 
     exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(p_exec_provider); diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc index ea93db58339c7..4f5fa9910b5df 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -53,128 +53,200 @@ Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrRe // clang-format off constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = { 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0xbc, 0x06, 0x00, 0x00, - 0x4c, 0x02, 0x00, 0x00, 0xe0, 0x01, 0x00, 0x00, 0xe0, 0x00, 0x00, 0x00, 0x14, 0x06, 0x00, 0x00, - 0x88, 0x01, 0x00, 0x00, 0xb8, 0x05, 0x00, 0x00, 0x1c, 0x05, 0x00, 0x00, 0x18, 0x07, 0x00, 0x00, - 0xcc, 0x04, 0x00, 0x00, 0x0c, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x54, 0x05, 0x00, 0x00, - 0x3c, 0x06, 0x00, 0x00, 0xf8, 0x02, 0x00, 0x00, 0x7c, 0x02, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x38, 0x03, 0x00, 0x00, 0xec, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, + 0x4c, 0x0b, 0x00, 0x00, 0xac, 0x08, 0x00, 0x00, 0xd0, 0x0a, 0x00, 0x00, 0x10, 0x06, 0x00, 0x00, + 0xa8, 0x07, 0x00, 0x00, 0x18, 0x03, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x44, 0x07, 0x00, 0x00, 0x9c, 0x01, 0x00, 0x00, 0xf8, 0x07, 0x00, 0x00, 0x78, 0x09, 0x00, 0x00, + 0x14, 0x01, 0x00, 0x00, 0x50, 0x06, 0x00, 0x00, 0x60, 0x02, 0x00, 0x00, 0xf4, 0x08, 0x00, 0x00, + 0x8c, 0x03, 0x00, 0x00, 0x9c, 0x02, 0x00, 0x00, 0x84, 0x06, 0x00, 0x00, 0xcc, 0x03, 0x00, 0x00, + 0x60, 0x05, 0x00, 0x00, 0xb8, 0x01, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0x08, 0x04, 0x00, 0x00, + 0xe0, 0x09, 0x00, 0x00, 0x8c, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xf4, 0xff, 0xff, + 0x08, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xda, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf4, 0xff, 0xff, + 0xd8, 0xf4, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, + 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf5, 0xff, 0xff, 0xa4, 0x0a, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfc, 0xf4, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x2c, 0xf5, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x4e, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x48, 0xf5, 0xff, 0xff, 0xc8, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x30, 0xf5, 0xff, 0xff, 0x6c, 0xf5, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 
0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x39, 0x00, 0x00, 0x9c, 0xf5, 0xff, 0xff, 0x3c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc2, 0xf5, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x94, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xc4, 0xf5, 0xff, 0xff, + 0xe8, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xb4, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xac, 0xf5, 0xff, 0xff, + 0xe8, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf6, 0xff, 0xff, 0xac, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x36, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xf8, 0xf5, 0xff, 0xff, 0x34, 0xf6, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, + 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, + 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x74, 0xf6, 0xff, 0xff, + 0x38, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x64, 0xf6, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x5c, 0xf6, 0xff, 0xff, + 0x98, 0xf6, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbe, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf6, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0xe4, 0xf6, 0xff, 0xff, + 0x2c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x0a, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xcc, 0xf6, 0xff, 0xff, + 0x08, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x30, 0xf7, 0xff, 0xff, 0xe0, 0x08, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x56, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0xf7, 0xff, 0xff, 0x54, 0xf7, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, + 0x78, 0xf7, 0xff, 0xff, 0x98, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9e, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x60, 0xf7, 0xff, 0xff, 0x9c, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x4d, 0x61, - 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0x20, 0xf9, 0xff, 0xff, 0xf0, 
0x06, 0x00, 0x00, + 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0xd0, 0xf7, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0e, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xf9, 0xff, 0xff, 0x44, 0xf9, 0xff, 0xff, + 0xf6, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xf7, 0xff, 0xff, 0xf4, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x6c, 0xf9, 0xff, 0xff, 0xa4, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5a, 0xf9, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x54, 0xf9, 0xff, 0xff, 0x90, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, 0xb4, 0xf9, 0xff, 0xff, - 0x5c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa2, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf9, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0x1c, 0xf8, 0xff, 0xff, 0xf4, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xf8, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xf8, 0xff, 0xff, 0x40, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, + 0x68, 0xf8, 0xff, 0xff, 0xa8, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xf8, 0xff, 0xff, 0x8c, 0xf8, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, 0xc8, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, + 0x0c, 0x01, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, + 0xd8, 0xf8, 0xff, 0xff, 0xdc, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc4, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xf4, 0xf8, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x22, 0xf9, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xf4, 0xf8, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x24, 0xf9, 0xff, 0xff, + 0xe4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0xf9, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0x40, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf9, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x68, 0xf9, 0xff, 0xff, 0x70, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xf9, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, + 0x60, 0xf9, 
0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x90, 0xf9, 0xff, 0xff, 0x1c, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x80, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x78, 0xf9, 0xff, 0xff, 0xb4, 0xf9, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0xd8, 0xf9, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xb4, 0x01, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0xfa, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x1c, 0xfa, 0xff, 0xff, 0xf4, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0a, 0xfa, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x04, 0xfa, 0xff, 0xff, 0x40, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, - 0x68, 0xfa, 0xff, 0xff, 0x3c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x56, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x50, 0xfa, 0xff, 0xff, 0x8c, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, - 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0xb4, 0xfa, 0xff, 0xff, - 0x00, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xfc, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd0, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, + 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x04, 0xfa, 0xff, 0xff, + 0x84, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xf0, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x20, 0xfa, 0xff, 0xff, 0xf0, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbe, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xfa, 0xff, 0xff, 0xf4, 0xfa, 0xff, 0xff, + 0x46, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xfa, 0xff, 0xff, 0x44, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, - 0x31, 0x31, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0x98, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x38, 0xfb, 0xff, 0xff, 0xd8, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x26, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x20, 0xfb, 0xff, 0xff, 0x5c, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 
0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, - 0x88, 0xfb, 0xff, 0xff, 0x88, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xfb, 0xff, 0xff, 0xac, 0xfb, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x00, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd4, 0xfb, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x31, 0x00, 0x00, 0x00, 0xfc, 0xfb, 0xff, 0xff, 0x14, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xea, 0xfb, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xe4, 0xfb, 0xff, 0xff, 0x20, 0xfc, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x38, 0x01, 0x00, 0x00, 0xdc, 0x00, 0x00, 0x00, - 0xa8, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, - 0x48, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, - 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, - 0x76, 0x3a, 0x31, 0x00, 0x6c, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbc, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x90, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe4, 0xfc, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, - 0xb8, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, - 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0c, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xe0, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd6, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x3c, 0xfd, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x10, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0xfd, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, - 0x6c, 0xfd, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x40, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x94, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, - 0x68, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbc, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x58, 0xfd, 0xff, 0xff, 0x94, 0xfd, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, - 0xb8, 0xfd, 0xff, 0xff, 0x58, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 
0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa6, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xa0, 0xfd, 0xff, 0xff, 0xdc, 0xfd, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0xff, - 0xa0, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xf2, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xec, 0xfd, 0xff, 0xff, - 0x28, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x50, 0xfe, 0xff, 0xff, 0xc0, 0x01, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x3e, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, 0x74, 0xfe, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, - 0x00, 0x00, 0x00, 0x00, 0x9c, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x92, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfe, 0xff, 0xff, - 0xc8, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xfe, 0xff, 0xff, 0x20, 0x01, 0x00, 0x00, + 0x31, 0x31, 0x00, 0x00, 0x6c, 0xfa, 0xff, 0xff, 0xc4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x88, 0xfa, 0xff, 0xff, 0x88, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xae, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x70, 0xfa, 0xff, 0xff, 0xac, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0xd0, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xde, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xfe, 0xff, 0xff, 0x14, 0xff, 0xff, 0xff, + 0xf6, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xfa, 0xff, 0xff, 0xf4, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x3c, 0xff, 0xff, 0xff, 0xd4, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2a, 0xff, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x24, 0xff, 0xff, 0xff, 0x60, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0xf4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xfb, 0xff, 0xff, + 
0x00, 0x00, 0x00, 0x01, 0x04, 0xfb, 0xff, 0xff, 0x40, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, + 0x68, 0xfb, 0xff, 0xff, 0xa8, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xfb, 0xff, 0xff, 0x8c, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfb, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xa4, 0xfb, 0xff, 0xff, 0xe0, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, + 0x08, 0xfc, 0xff, 0xff, 0x08, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xf0, 0xfb, 0xff, 0xff, 0x2c, 0xfc, 0xff, 0xff, 0x04, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x18, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x48, 0xfc, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x30, 0x00, 0x00, 0x7c, 0xfc, 0xff, 0xff, 0x30, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfc, 0xff, 0xff, 0x94, 0xfc, 0xff, 0xff, + 0x44, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xba, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfc, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0xbc, 0xfc, 0xff, 0xff, 0x4c, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xd8, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, + 0x00, 0x00, 0x00, 0x00, 0x0c, 0xfd, 0xff, 0xff, 0xcc, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x34, 0xfd, 0xff, 0xff, + 0x78, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x24, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x1c, 0xfd, 0xff, 0xff, + 0x58, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 
0x3a, 0x31, 0x33, 0x00, 0x80, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x78, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xa8, 0xfd, 0xff, 0xff, 0x68, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xce, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xfd, 0xff, 0xff, 0xcc, 0xfd, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, + 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x28, 0xfe, 0xff, 0xff, 0x84, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0xff, 0x40, 0xfe, 0xff, 0xff, 0x98, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x66, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x68, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, + 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0xa4, 0xfe, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9c, 0xfe, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x94, 0xfe, 0xff, 0xff, 0xd0, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfe, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xd0, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, - 0x88, 0xff, 0xff, 0xff, 0x88, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xff, 0xff, 0xff, 0xac, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0xdc, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x28, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x20, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x50, 0xff, 0xff, 0xff, 0xc0, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x76, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xff, 0xff, 0xff, 0x74, 0xff, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, + 0x00, 0x00, 0x00, 0x00, 0xac, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa4, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd4, 0xff, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, }; // clang-format on diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8deeb4c2b8b64..40c59cfcf699d 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -5,6 +5,8 @@ #include #include +#include +#include #include "core/common/gsl.h" #include "core/common/inlined_containers.h" #include "core/framework/config_options.h" @@ -24,6 +26,21 @@ enum class ExecutionOrder { PRIORITY_BASED = 1 // priority-based topological sort }; +inline std::ostream& operator<<(std::ostream& os, const ExecutionOrder& order) { + switch (order) { + case ExecutionOrder::DEFAULT: + os << "DEFAULT"; + break; + case ExecutionOrder::PRIORITY_BASED: + os << "PRIORITY_BASED"; + break; + default: + os << "UNKNOWN"; + break; + } + return os; +} + enum class FreeDimensionOverrideType { Invalid = 0, Denotation = 1, @@ -89,6 +106,7 @@ struct SessionOptions { /// Log severity for the inference session. Applies to session load, initialization, etc. /// See https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/common/logging/severity.h + /// See https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_c_api.h#L231 for OrtLoggingLevel mappings /// Default = -1 (use default logger severity) int session_log_severity_level = -1; int session_log_verbosity_level = 0; ///< VLOG level if debug build and session_log_severity_level is 0 (VERBOSE). 
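Together with the ConfigOptions printer added earlier in this patch, these stream operators let the whole session configuration be dumped in one statement; a minimal sketch (the helper name is hypothetical, and the headers are assumed to be on the include path):

    #include <iostream>

    #include "core/framework/session_options.h"

    // Sketch: the SessionOptions printer added in the next hunk pulls in the
    // ExecutionOrder and ConfigOptions printers defined earlier in this patch.
    inline void DumpSessionOptions(const onnxruntime::SessionOptions& opts) {
      std::cout << opts << std::endl;  // e.g. "Session Options {  execution_mode:0 ... }"
    }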
@@ -154,4 +172,37 @@ struct SessionOptions { void* user_logging_param = nullptr; }; +inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { + os << "Session Options { " + << " execution_mode:" << session_options.execution_mode + << " execution_order:" << session_options.execution_order + << " enable_profiling:" << session_options.enable_profiling + << " optimized_model_filepath:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath) + << " enable_mem_pattern:" << session_options.enable_mem_pattern + << " enable_mem_reuse:" << session_options.enable_mem_reuse + << " enable_cpu_mem_arena:" << session_options.enable_cpu_mem_arena + << " profile_file_prefix:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.profile_file_prefix) + << " session_logid:" << session_options.session_logid + << " session_log_severity_level:" << session_options.session_log_severity_level + << " session_log_verbosity_level:" << session_options.session_log_verbosity_level + << " max_num_graph_transformation_steps:" << session_options.max_num_graph_transformation_steps + << " graph_optimization_level:" << static_cast<int>(session_options.graph_optimization_level) + << " intra_op_param:" << session_options.intra_op_param + << " inter_op_param:" << session_options.inter_op_param + //<< " free_dimension_overrides:" << session_options.free_dimension_overrides + << " use_per_session_threads:" << session_options.use_per_session_threads + << " thread_pool_allow_spinning:" << session_options.thread_pool_allow_spinning + << " use_deterministic_compute:" << session_options.use_deterministic_compute + << " config_options: { " << session_options.config_options << " }" + //<< " initializers_to_share_map:" << session_options.initializers_to_share_map +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS) + //<< " external_initializers:" << session_options.external_initializers +#endif +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + //<< " custom_op_libs:" << session_options.custom_op_libs +#endif + << " }"; + return os; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensor_shape.cc b/onnxruntime/core/framework/tensor_shape.cc index 521f4062c1ff6..399dc1a2a4e69 100644 --- a/onnxruntime/core/framework/tensor_shape.cc +++ b/onnxruntime/core/framework/tensor_shape.cc @@ -63,7 +63,7 @@ int64_t TensorShape::Size() const { int64_t TensorShape::SizeToDimension(size_t dimension) const { const size_t num_dims = values_.size(); ORT_ENFORCE(dimension <= num_dims, - "Invalid dimension of ", dimension, " for SizeFromDimension. Tensor has ", + "Invalid dimension of ", dimension, " for SizeToDimension. Tensor has ", num_dims, " dimensions."); int64_t size = SizeHelper(0, dimension); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index dcde2ddeb8270..ea67218b5c927 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,7 +259,6 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); - return; } else { fail_shape_inference("Missing input 2 (value)"); } @@ -991,7 +990,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( Group Query Self/Cross Attention. 
-Supports different number of heads for q and kv. +Supports different number of heads for q and kv. Only supports causal or local attention. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1004,10 +1003,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) - // .Attr("left_padding_last_token", - // "Copy last token to last index of buffer. Default is 0; 1 when true.", - // AttributeProto::INT, - // OPTIONAL_VALUE) + .Attr("local_window_size", + "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", + AttributeProto::INT, + static_cast<int64_t>(-1)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size)", @@ -1144,7 +1143,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OPTIONAL_VALUE) .Input(0, "input", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", "T") .Input(1, "position_ids", @@ -1160,7 +1159,7 @@ "T") .Output(0, "output", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "tensor with same shape as input.", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 59adfc523c860..4aa43f5de1cd5 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -80,6 +80,60 @@ void RegisterCollectiveOps() { propagateShapeAndTypeFromFirstInput(ctx); }); + ONNX_CONTRIB_OPERATOR_SCHEMA(ShardedMoE) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("activation_type", + "Activation function to use. Choose from relu, gelu, silu and identity. 
Default is relu", + AttributeProto::STRING, + std::string("relu")) + .Attr("k", + "Number of top experts to select from expert pool", + AttributeProto::INT, + static_cast(1)) + .Attr("local_experts_start_index", + "The start index of local experts", + AttributeProto::INT, + static_cast(-1)) + .Input(0, + "input", + "2D input tensor with shape (num_rows, hidden_size) or " + "3D input tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .Input(1, + "router_probs", + "2D input tensor with shape (num_rows, num_experts)", + "T") + .Input(2, + "fc1_experts_weights", + "3D input tensor with shape (local_num_experts, hidden_size, inter_size)", + "T") + .Input(3, + "fc2_experts_weights", + "3D input tensor with shape (local_num_experts, inter_size, hidden_size)", + "T") + .Input(4, + "fc1_experts_bias", + "2D optional input tensor with shape (local_num_experts, inter_size)", + "T", + OpSchema::Optional) + .Input(5, + "fc2_experts_bias", + "2D optional input tensor with shape (num_experts, hidden_size)", + "T", + OpSchema::Optional) + .Output(0, + "output", + "2D input tensor with shape (num_rows, hidden_size) or " + "3D input tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", + {"tensor(float)", "tensor(float16)"}, + "Constrain input and output types to float or float16 tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + }); + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index db0b13b0e1d27..54eb43753931a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3248,7 +3248,7 @@ void RegisterContribSchemas() { "List of tensors for inputs", "T", OpSchema::Variadic, - true, + false, 1, OpSchema::NonDifferentiable) .Output( @@ -3257,7 +3257,7 @@ void RegisterContribSchemas() { "One or more outputs, list of tensors for outputs", "T", OpSchema::Variadic, - true, + false, 1, OpSchema::NonDifferentiable) .TypeConstraint( @@ -3273,11 +3273,7 @@ void RegisterContribSchemas() { "tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types.") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - // Type inference - propagateElemTypeFromInputToOutput(ctx, 0, 0); - }); + "Constrain input and output types."); static const char* BitmaskDropout_ver1_doc = R"DOC( BitmaskDropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar). @@ -3363,6 +3359,13 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored .Attr("N", "size of each output feature", AttributeProto::INT) .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("accuracy_level", + "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) " + "(default unset). It is used to control how input A is quantized or downcast internally while " + "doing computation, for example: 0 means input A will not be quantized or downcast while doing " + "computation. 
4 means input A can be quantized with the same block_size to int8 internally from " + "type T1.", + AttributeProto::INT, static_cast<int64_t>(0)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1-dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") @@ -3431,7 +3434,7 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 .Input(1, "B", "1-dimensional quantized data for weight", "T2") .Input(2, "absmax", "quantization constants", "T1") .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") - .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float/half_float/brain_float tensors.") .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 03ad95260c0ad..c8960578f9e3d 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -101,6 +101,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMESPACE::OpSchema&&)>& fn) { diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc onnx_func_proto_; return true; } else if (op_) { + auto get_opset_version = [op = op_](Graph* graph) -> std::optional<int> { + if (op->domain() == kOnnxDomain) { + const auto& domain_to_version = graph->DomainToVersionMap(); + const auto iter = domain_to_version.find(kOnnxDomain); + if (iter != domain_to_version.cend()) { + return iter->second; + } + } + return {}; + }; + // Check if this node has a schema defined function proto. if (op_->HasContextDependentFunction()) { NodeProto node_proto; @@ -595,8 +606,13 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot } else input_types.emplace_back(); } + + auto requested_opset_version = get_opset_version(graph_); + if (!requested_opset_version.has_value()) { + requested_opset_version = SinceVersion(); + } ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types); - return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto); + return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto, *requested_opset_version); } else if (op_->HasFunction()) { const FunctionProto* function_ptr = nullptr; // We need to get a function-body suitable for the ONNX opset used by the model. @@ -605,17 +621,12 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot // as the default-version, which is incorrect in the case of functions belonging to // non-onnx domains, like MSDOMAIN. - // We use the following as a temporary hack. - function_ptr = op_->GetFunction(SinceVersion(), false); - - // TODO: Switch to following, once ONNX issue is fixed. 
- // auto& map = graph_->DomainToVersionMap(); - // const auto iter = map.find(kOnnxDomain); - // if (iter != map.end()) { - // function_ptr = op_->GetFunction(iter->second, true); - // } else { - // function_ptr = op_->GetFunction(); - // } + auto requested_opset_version = get_opset_version(graph_); + if (requested_opset_version.has_value()) { + function_ptr = op_->GetFunction(*requested_opset_version, true); + } else { + function_ptr = op_->GetFunction(SinceVersion(), false); + } if (function_ptr != nullptr) { onnx_function_proto = *function_ptr; @@ -4062,7 +4073,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap if (input_def->Exists()) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { - input_def = hit->second; + // Make sure we create a definition local to this subgraph + const auto* new_name_arg = hit->second; + input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto()); } } } @@ -4088,7 +4101,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin Graph& graph_to_inline = *sub_graph; - std::string unique_id{if_node.Name()}; + std::string unique_id{"_if_"}; if (condition_value) { unique_id.append(then_branch); } else { @@ -4107,7 +4120,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin // Reason: there are no explicit inputs to the subgraphs, and the subgraph's // implicit inputs must be covered by the implicit inputs of the If node. InlinedHashMap<std::string_view, NodeArg*> outer_scope_values; - const auto if_implicit_inputs = if_node.MutableImplicitInputDefs(); + const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs(); outer_scope_values.reserve(if_implicit_inputs.size()); for (auto* input : if_implicit_inputs) { @@ -4121,8 +4134,8 @@ // We are going to map the outputs of the graph to inline to the outputs of the If node. // They are assumed to be in the same order. - const auto node_output_defs = if_node.MutableOutputDefs(); - const auto graph_output_defs = graph_to_inline.GetOutputs(); + const auto& node_output_defs = if_node.MutableOutputDefs(); + const auto& graph_output_defs = graph_to_inline.GetOutputs(); for (size_t i = 0; i < graph_output_defs.size(); ++i) { name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]); } @@ -4206,6 +4219,7 @@ } } + auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr); // We want to make sure we get nodes in topological order // because Constant folding may cause the nodes appear in // a different order. 
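The next hunk repeatedly trims trailing missing (optional) defs while keeping holes in the middle; the idiom in isolation, with a stand-in NodeArg type (a sketch only, not the onnxruntime class), looks like this:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Stand-in for onnxruntime::NodeArg, for illustration only.
    struct NodeArg {
      bool exists;
      bool Exists() const { return exists; }
    };

    // Drop trailing non-existing defs but keep non-existing entries in the middle.
    inline void TrimTrailingMissing(std::vector<NodeArg*>& defs) {
      auto last_existing = std::find_if(defs.rbegin(), defs.rend(),
                                        [](const NodeArg* arg) { return arg->Exists(); });
      // base() points one past the last existing def; if none exist, the vector is cleared.
      defs.resize(static_cast<size_t>(std::distance(defs.begin(), last_existing.base())));
    }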
@@ -4216,68 +4230,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin auto* node = graph_to_inline.GetNode(node_idx); assert(node->OpType() != kConstant); - InlinedVector<NodeArg*> new_node_input_defs; - for (const auto* input_def : node->InputDefs()) { + // Inputs + // Chop off trailing non-existing defs, but preserve non-existing in the middle + auto& input_defs = node->MutableInputDefs(); + auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + input_defs.resize(std::distance(input_defs.begin(), last_existing.base())); + + InlinedVector<NodeArg*> new_input_defs; + for (auto* input_def : node->InputDefs()) { if (input_def->Exists()) { // Check if this is one of the implicit graph inputs - // then leave the name as is and re-use the NodeArg + // then re-assign the def to the outer scope value. const auto& input_name = input_def->Name(); auto outer_hit = outer_scope_values.find(input_name); if (outer_hit != outer_scope_values.cend()) { - new_node_input_defs.push_back(outer_hit->second); + // get/create local definition + NodeArg* outer_arg = outer_hit->second; + auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto()); + new_input_defs.push_back(&this_scope_arg); } else { auto hit = name_to_nodearg.find(input_name); if (hit != name_to_nodearg.cend()) { - // This is other node output, constant node or initializer that was renamed. - new_node_input_defs.push_back(hit->second); + // This is other node output in the dest graph, + // constant node or initializer that was renamed. + new_input_defs.push_back(hit->second); } else { ORT_THROW("Node's: ", node->Name(), " input: ", input_name, " is not If node's input or previous node output in this subgraph"); } } + } else { + new_input_defs.push_back(non_existing_arg); } } - InlinedVector<NodeArg*> new_node_output_defs; - for (const auto* output_def : node->OutputDefs()) { - const auto& output_name = output_def->Name(); - auto hit = name_to_nodearg.find(output_name); - if (hit != name_to_nodearg.cend()) { - // This is one of the graph outputs, we rename it to - // If node output. - new_node_output_defs.push_back(hit->second); + // Outputs + // Chop off trailing non-existing defs + auto& output_defs = node->MutableOutputDefs(); + last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + output_defs.resize(std::distance(output_defs.begin(), last_existing.base())); + + InlinedVector<NodeArg*> new_output_defs; + for (auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto& output_name = output_def->Name(); + auto hit = name_to_nodearg.find(output_name); + if (hit != name_to_nodearg.cend()) { + // This is one of the If node outputs, simply reassign the def. + // If node defs are already in the destination graph + new_output_defs.push_back(hit->second); + } else { + // We generate an output to downstream nodes. + auto new_name = GenerateNodeArgName(make_unique(output_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); + new_output_defs.push_back(&new_arg); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + } } else { - // We generate an output to downstream nodes. 
- auto new_name = GenerateNodeArgName(make_unique(output_name)); - NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); - new_node_output_defs.push_back(&new_arg); - ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + new_output_defs.push_back(non_existing_arg); } } const auto new_node_name = GenerateNodeName(make_unique(node->OpType())); Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(), - new_node_input_defs, - new_node_output_defs, + new_input_defs, + new_output_defs, nullptr, node->Domain()); + new_node.SetSinceVersion(node->SinceVersion()); + new_node.op_ = node->op_; + if (!is_this_main_graph) { map_defs(new_node, input_args, true); map_defs(new_node, output_args, false); new_nodes.push_back(&new_node); } - new_node.SetSinceVersion(node->SinceVersion()); - new_node.op_ = node->op_; - if (node->ContainsSubgraph()) { auto& subgraphs = node->MutableSubgraphs(); // Check if any of this node implicit inputs of this graph is in the renaming map + // that would mean they come from the destination graph, not from the parent + // of the destination graph. int renames_subgraph_names = 0; - auto& new_implicit_defs = node->MutableImplicitInputDefs(); - for (auto& input_def : new_implicit_defs) { + auto& implicit_defs = node->MutableImplicitInputDefs(); + for (auto& input_def : implicit_defs) { auto hit = name_to_nodearg.find(input_def->Name()); if (hit != name_to_nodearg.cend()) { input_def = hit->second; @@ -4298,7 +4338,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin new_node.MutableSubgraphs() = std::move(subgraphs); new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph()); - new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs); + new_node.MutableImplicitInputDefs() = std::move(implicit_defs); } new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes()); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 5482a8e286da5..cf78040ea5ac6 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -35,6 +35,17 @@ struct PriorityNodeCompare { return n1->Priority() > n2->Priority(); } + // nodes of forward pass will be output first + auto n1_attrs = n1->GetAttributes(); + auto n2_attrs = n2->GetAttributes(); + int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) || + (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; + int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) || + (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; + if (n1_is_forward != n2_is_forward) { + return n2_is_forward > n1_is_forward; + } + // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); } @@ -57,6 +68,14 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) : ConstGraphNodes::NodeFilterFunc(nullptr))}, filter_info_{filter_info} { std::vector<NodeIndex> leaf_nodes; +#ifdef ENABLE_TRAINING + // Keep the info of shape and size nodes and their parents so that after topological sort, we can move them + // right after their parents. This is to make sure the shape and size nodes are executed right after their parents + // so it's possible the input tensor memory can be released as soon as possible. 
This is especially important + // for non-CPU devices or for the training case where some gradient graphs use only shape/size of tensors from forward. + InlinedHashSet<NodeIndex> shape_size_nodes; + InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>> shape_size_parents; +#endif for (auto& node : graph_->Nodes()) { // This is a leaf node (without any output node) if (node.OutputNodesBegin() == node.OutputNodesEnd()) { @@ -66,6 +85,17 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) if (node.InputEdgesBegin() == node.InputEdgesEnd()) { root_nodes_.push_back(node.Index()); } +#ifdef ENABLE_TRAINING + if ((node.OpType() == "Shape" || node.OpType() == "Size") && node.InputEdgesBegin() != node.InputEdgesEnd()) { + shape_size_nodes.insert(node.Index()); + NodeIndex parent = node.InputNodesBegin()->Index(); + if (shape_size_parents.find(parent) == shape_size_parents.end()) { + shape_size_parents[parent] = InlinedVector<NodeIndex>{node.Index()}; + } else { + shape_size_parents[parent].push_back(node.Index()); + } + } +#endif } graph.ReverseDFSFrom( @@ -75,7 +105,24 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) nodes_in_topological_order_.push_back(n->Index()); }, NodeCompare()); - +#ifdef ENABLE_TRAINING + auto original = std::move(nodes_in_topological_order_); + nodes_in_topological_order_.reserve(original.size()); + InlinedHashSet<NodeIndex> visited; + for (auto& node : original) { + if (visited.find(node) != visited.end()) { + continue; + } + nodes_in_topological_order_.push_back(node); + visited.insert(node); + if (shape_size_parents.find(node) != shape_size_parents.end()) { + for (auto& following_node : shape_size_parents[node]) { + nodes_in_topological_order_.push_back(following_node); + visited.insert(following_node); + } + } + } +#endif #if !defined(ORT_MINIMAL_BUILD) graph.KahnsTopologicalSort( [this](const Node* n) {
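The ENABLE_TRAINING reordering above is easier to see on a toy input; a standalone sketch (plain std containers instead of the Inlined* types, with NodeIndex reduced to size_t) behaves like this:

    #include <cstdio>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    using NodeIndex = size_t;

    // Stable pass: hoist each Shape/Size consumer to directly follow its parent.
    std::vector<NodeIndex> MoveShapeSizeAfterParents(
        const std::vector<NodeIndex>& topo_order,
        const std::unordered_map<NodeIndex, std::vector<NodeIndex>>& shape_size_parents) {
      std::vector<NodeIndex> result;
      result.reserve(topo_order.size());
      std::unordered_set<NodeIndex> visited;
      for (NodeIndex node : topo_order) {
        if (visited.count(node) != 0) continue;
        result.push_back(node);
        visited.insert(node);
        auto it = shape_size_parents.find(node);
        if (it != shape_size_parents.end()) {
          for (NodeIndex child : it->second) {
            result.push_back(child);
            visited.insert(child);
          }
        }
      }
      return result;
    }

    int main() {
      // Parent 0 feeds Shape/Size nodes 3 and 4: [0, 1, 2, 3, 4] -> [0, 3, 4, 1, 2].
      for (NodeIndex n : MoveShapeSizeAfterParents({0, 1, 2, 3, 4}, {{0, {3, 4}}}))
        std::printf("%zu ", n);
      std::printf("\n");
    }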
diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md new file mode 100644 index 0000000000000..7e8d30cd1805b --- /dev/null +++ b/onnxruntime/core/mickey/README.md @@ -0,0 +1,6 @@ +# About Mickey + +A playful name for a template library of high-performance CUDA code that +is often shared by various AI operators. The intention is to keep this +library header-only, with no binary impact unless it is instantiated +where it is needed. diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/prepack_sm80.h new file mode 100644 index 0000000000000..e291ab39e8aa3 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/prepack_sm80.h @@ -0,0 +1,325 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * prepack_sm80.h + * + * Abstract: + * Prepack weights and quantization parameters (scales and offsets) for + * GEMM, where activations are fp16 or bf16, and weights are block-wise + * 4b quantized values, specifically for Ampere GPUs. + * + * Prepacking enables faster loading of weights and quantization parameters + * into tensor cores, and faster dequantization of weights. + * + * Only supports fp16 for now, bfloat16 support will be added later. + */ + +#pragma once + +#include "core/common/common.h" +#include "core/util/matrix_layout.h" + +namespace onnxruntime { +namespace cuda { + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, fp16 + * @tparam block_size number of elements quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int block_size, + int qbits, + bool Columnwise, + bool ExtraBoundsCheck = false> +struct BlockwiseQuantization { + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(sizeof(ElementT) == 2, "Only 16b floating point types are supported!"); + + using QuantBlocking = + std::conditional_t<Columnwise, MatrixShape<block_size, 1>, + MatrixShape<1, block_size>>; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + // We pack 4 weights into one 16b element, so we can leverage cutlass tile iterators + // for async shared memory loading and to minimize bank conflicts during matrix loading + using ElementWPack = ElementT; + using LayoutWPack = ColumnMajorLayout; // <- layout of packed weight, must be column major + + // Current Ampere kernel uses 8b zero points, need to shrink them to 4b in the future + using ElementQOffset = uint8_t; + + // Layout of the quantization parameters (scales and zero points) + // Major on the dimension that has the most parameters per squarish weight block. + // E.g. for column-wise quantization, a [64, 64] block has [2, 64] parameters, + // where each row has more data, so we use row major layout so that warp threads + // can use less load instructions to load more parameters. + using LayoutQmeta = + typename std::conditional<Columnwise, RowMajorLayout, ColumnMajorLayout>::type; + + /** + * @brief Get quantized weight tensor dimensions. + * Actual weight type is int4, we use ElementW = uint8 to avoid possible compilation + * troubles. Since the layout is column major, we are packing 2 weights in a column + * into one int8 + */ + static inline auto get_quant_weights_shape(int rows, int columns) { + return make_Position(rows / 2, columns); + } + + static inline auto get_quant_meta_shape(int rows, int columns) { + return make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + } + + /** + * @brief Prepack weight matrix to facilitate matrix loading, depending on MMA + * instruction layout. + * + * The weight matrix is int4, yet we want to leverage existing fp16/bf16 + * tile loading and MMA layout code in CUTLASS. So we group 4 int4 into 2 + * bytes, pretending it's fp16. This grouping must be done in a way to be + * easily unpacked into tiles that match the MMA instruction layout. + * For MMA instruction <16, 8, 16>, each instruction processes 2 8x8 tiles, + * vertically stacked on the K dimension. And MmaTensorOpMultiplicandTileIterator + * loads a tile. + * + * So we stack 2x2 tiles on a 3rd dimension, and reshape them in a HWC fashion: + * T0, T2 + * T1, T3 + * ==> + * T0[0, 0], T1[0, 0], T2[0, 0], T3[0, 0] + * T0[1, 0], T1[1, 0], T2[1, 0], T3[1, 0] + * T0[2, 0], T1[2, 0], T2[2, 0], T3[2, 0] + * T0[3, 0], T1[3, 0], T2[3, 0], T3[3, 0] + * ... + * T0[0, 7], T1[0, 7], T2[0, 7], T3[0, 7] + * T0[1, 7], T1[1, 7], T2[1, 7], T3[1, 7] + * T0[2, 7], T1[2, 7], T2[2, 7], T3[2, 7] + * T0[3, 7], T1[3, 7], T2[3, 7], T3[3, 7] + * + * This packs an 8x16 int8 tile into a 16x8 int8 tile, i.e. an 8x8 16b tile + */
+   */
+  static void prepack_weights(
+      int rows,
+      int columns,
+      const gsl::span<uint8_t const>& weights,     // <- int4 weights, column major
+      const gsl::span<uint8_t>& weights_prepacked  // <- int4 prepacked weights tensor, same size buffer
+  ) {
+    ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 &&
+                    (rows % QuantBlocking::kRow) == 0 &&
+                    (columns % QuantBlocking::kColumn) == 0,
+                "Does not support odd number of rows or columns!");
+    ORT_ENFORCE(weights.size() == size_t(rows * columns / 2),
+                "Weight tensor shape mismatch!");
+    ORT_ENFORCE(weights_prepacked.size() == weights.size(),
+                "Prepacked weight tensor buffer should be the same size!");
+
+    const MatrixRef<uint8_t const, ColumnMajorLayout, true>
+        tensor_weight(weights, make_Position(rows / 2, columns));
+    const MatrixRef<ElementW, ColumnMajorLayout, true>
+        tensor_weight_prepacked(weights_prepacked, make_Position(rows, columns / 2));
+
+    // TODO(fuchen)!! parallelize this.
+    auto t0_base = make_Position(0, 0);
+    auto t1_base = make_Position(4, 0);
+    auto t2_base = make_Position(0, 8);
+    auto t3_base = make_Position(4, 8);
+    for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) {
+      for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) {
+        // Packing from an 8x16 tile to a 16x8 tile
+        auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16);
+        auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8);
+        for (int col = 0; col < 8; ++col) {
+          for (int row = 0; row < 4; ++row) {
+            auto cord = make_Position(row, col);
+            auto packed_cord = packed_tile_base + make_Position(row * 4, col);  // packed tile is 16x8
+            uint8_t buf[4];
+            buf[0] = tensor_weight.at(dtile_base + t0_base + cord);
+            buf[1] = tensor_weight.at(dtile_base + t1_base + cord);
+            buf[2] = tensor_weight.at(dtile_base + t2_base + cord);
+            buf[3] = tensor_weight.at(dtile_base + t3_base + cord);
+
+            // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7], so that each pair of
+            // adjacent weights ends up in different b16 registers at the same positions.
+            // This makes it easier to convert to the fp16x2 format in a b32 register.
+
+            tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4);
+            tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4);
+            tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0);
+            tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0);
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * @brief We rearrange the values of the quantization scale and offset tensors to
+   *        facilitate faster loading into tensor cores; only for 16b gemm and (1,n)
+   *        block quantization.
+   */
+  static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1;
+
+  static void prepack_quant_scales(
+      size_t rows,
+      size_t columns,
+      const gsl::span<ElementT const>& scales,     // <- quant scales, column major layout
+      const gsl::span<ElementT>& scales_prepacked  // <- quant scales prepacked, same size buffer
+  ) {
+    auto meta_shape = get_quant_meta_shape(rows, columns);
+    ORT_ENFORCE(scales.size() == size_t(meta_shape.product()),
+                "Quantization scale tensor shape mismatch!");
+    ORT_ENFORCE(scales_prepacked.size() == size_t(meta_shape.product()),
+                "Prepacked quantization scale tensor buffer should be the same size!");
+
+    MatrixRef<ElementT const, LayoutQmeta, true> tensor_scale(scales, meta_shape);
+    MatrixRef<ElementT, LayoutQmeta, true> tensor_scale_prepacked(scales_prepacked, meta_shape);
+
+    // Only prepack scale and offset tensors for an often-used special case:
+    //   16b gemm (2 elements per 32b register, operand tile shape 8x8)
+    //   2 B operand tiles per mma instruction stacked on the k dimension
+    //   (1,n) quantization blocking
+    if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) {
+      // In Ampere tensor op, each operand B tile is 8 x 8; in a warp of 32 threads, each thread
+      // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use
+      // an mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
+      // as shown below (T stands for thread):
+      //   T0, T4, T8, T12
+      //   T1, T5, T9, T13
+      //   T2, T6, T10, T14
+      //   T3, T7, T11, T15
+      //   T0, T4, T8, T12
+      //   T1, T5, T9, T13
+      //   T2, T6, T10, T14
+      //   T3, T7, T11, T15
+      //
+      // We need to deliver quantization scale and offset elements to the corresponding threads,
+      // so we can perform dequantization efficiently. With a column major layout, each thread
+      // needs two separate loads for an mma instruction, due to the tile fragment layout shown
+      // above.
+      // To reduce the number of loads, we rearrange each column as below, so we can use
+      // a single load to fetch fragments for two tiles:
+      //   T0        T0
+      //   T1        T0
+      //   T2        T1
+      //   T3   =>   T1
+      //   T0        T2
+      //   T1        T2
+      //   T2        T3
+      //   T3        T3
+
+      for (int col = 0; col < tensor_scale.shape()[1]; ++col) {
+        for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) {
+          for (int thread_id = 0; thread_id < 4; thread_id++) {
+            const int dst_idx = row_blk + thread_id * 4;
+            const int src_idx = row_blk + thread_id * 2;
+            tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col);
+            tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col);
+            tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col);
+            tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col);
+          }
+        }
+      }
+    } else {
+      // In all other cases, we don't prepack scale or offset.
+      // Potential transpose if the prepacked layout is different from the original layout.
+      for (int col = 0; col < tensor_scale.shape()[1]; ++col) {
+        for (int row = 0; row < tensor_scale.shape()[0]; ++row) {
+          tensor_scale_prepacked.at(row, col) = tensor_scale.at(row, col);
+        }
+      }
+    }
+  }
+
+  static void prepack_quant_offsets(
+      size_t rows,
+      size_t columns,
+      const gsl::span<uint8_t const>& offsets,     // <- quant offsets, int4, column major layout
+      const gsl::span<uint8_t>& offsets_prepacked  // <- quant offsets prepacked, double size buffer
+  ) {
+    auto meta_shape = get_quant_meta_shape(rows, columns);
+
+    ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0,
+                "Does not support odd number of rows or columns!");
+    ORT_ENFORCE(offsets_prepacked.size() == size_t(meta_shape.product()),
+                "Wrong buffer size for prepacked quantization offsets!");
+    ORT_ENFORCE(offsets.size() == size_t(((meta_shape[0] + 1) / 2) * meta_shape[1]),
+                "Quantization offset tensor shape mismatch!");
+
+    MatrixRef<uint8_t const, ColumnMajorLayout, true>
+        tensor_offset(offsets, make_Position((meta_shape[0] + 1) / 2, meta_shape[1]));
+    MatrixRef<uint8_t, LayoutQmeta, true> tensor_offset_prepacked(offsets_prepacked, meta_shape);
+
+    // Only prepack scale and offset tensors for an often-used special case:
+    //   16b gemm (2 elements per 32b register, operand tile shape 8x8)
+    //   2 B operand tiles per mma instruction stacked on the k dimension
+    //   (1,n) quantization blocking
+    if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) {
+      // In Ampere tensor op, each operand B tile is 8 x 8; in a warp of 32 threads, each thread
+      // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use
+      // an mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
+      // as shown below (T stands for thread):
+      //   T0, T4, T8, T12
+      //   T1, T5, T9, T13
+      //   T2, T6, T10, T14
+      //   T3, T7, T11, T15
+      //   T0, T4, T8, T12
+      //   T1, T5, T9, T13
+      //   T2, T6, T10, T14
+      //   T3, T7, T11, T15
+      //
+      // We need to deliver quantization scale and offset elements to the corresponding threads,
+      // so we can perform dequantization efficiently. With a column major layout, each thread
+      // needs two separate loads for an mma instruction, due to the tile fragment layout shown
+      // above.
To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row_blk = 0; row_blk < meta_shape[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + // Potential transpose if the prepacked layout is different from the original layout + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_offset_prepacked.at(row + 0, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_offset_prepacked.at(row + 1, col) = pair01 >> 4; + } + } + } + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index fd6b3df93444b..bdd4dba521eba 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -69,6 +69,9 @@ Module Name: #endif #endif +#if defined(__loongarch64) +#define MLAS_TARGET_LARCH64 +#endif // // Define the support levels for the target architecture. 
// @@ -87,7 +90,7 @@ Module Name: #define MLAS_F16VEC_INTRINSICS_SUPPORTED -#endif // +#endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic @@ -1619,7 +1622,7 @@ MlasHalfGemmConvertPackB( * @param Channels # of input channels * @param OutputCount # of output pixels * @param KernelSize # kernel size - * @return + * @return */ void MLASCALL @@ -1657,7 +1660,7 @@ MlasTranspose( * @param Channels C in NHWC * @param OutputCount Number of output pixels * @param KernelSize Size of the kernel - * @return + * @return */ void MLASCALL @@ -1676,7 +1679,7 @@ MlasNhwcMaxPool( * @param Channels C in NHWC * @param OutputCount Number of output pixels * @param KernelSize size of the kernel - * @return + * @return */ void MLASCALL diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 9620dd42d1da9..1e83dd1cec400 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -77,3 +77,144 @@ MlasIsSQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen ); + +/** + * @brief Define compute types of block quantization + */ +typedef enum { + CompUndef = 0, /*!< undef */ + CompFp32 = 1, /*!< input fp32, accumulator fp32 */ + CompFp16 = 2, /*!< input fp16, accumulator fp16 */ + CompBf16 = 3, /*!< input bf16, accumulator fp32 */ + CompInt8 = 4 /*!< input int8, accumulator int32 */ +} MLAS_SQNBIT_COMPUTE_TYPE; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t MLASCALL +MlasNBitsGemmPackBSize( + size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type +); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). 
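+ *        A sketch of the expected PrePack-driven sequence (argument values are
+ *        illustrative, and passing nullptr for tensors not yet available is an
+ *        assumption about the caller, not a requirement stated here):
+ *            MlasNBitsGemmPackB(buf, QData, nullptr, nullptr, N, K, ldb,
+ *                               blk, 4, true, false, CompInt8, tp);  // QData
+ *            MlasNBitsGemmPackB(buf, nullptr, Scale, nullptr, N, K, ldb,
+ *                               blk, 4, true, false, CompInt8, tp);  // Scale
+ *            MlasNBitsGemmPackB(buf, nullptr, nullptr, Zp, N, K, ldb,
+ *                               blk, 4, true, true, CompInt8, tp);   // Zp, last_call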
+ *        But the kernel prefers to pack all tensors into a single blob, where they can share
+ *        common attributes such as block_size. Meanwhile, the kernel performs some
+ *        pre-computations to speed up inference, and these require all of the blob data to
+ *        be ready. So this flag needs to be set to true when passing Scale (if is_asym is
+ *        false) or Zp (if is_asym is true).
+ * @param thread_pool
+ */
+void MLASCALL
+MlasNBitsGemmPackB(
+    void* PackedBuf,
+    const uint8_t* QData,
+    const float* Scale,
+    const uint8_t* Zp,
+    size_t N,
+    size_t K,
+    size_t ldb,
+    size_t block_size,
+    int nbits,
+    bool is_asym,
+    bool last_call,
+    MLAS_SQNBIT_COMPUTE_TYPE comp_type,
+    MLAS_THREADPOOL* thread_pool
+);
+
+/**
+ * @brief Unpack and dequantize to fp32
+ *
+ * @param FpData     unpacked float32 data
+ * @param PackedBuf  quantized and packed data
+ * @param N          the number of columns of matrix B.
+ * @param K          the number of rows of matrix B.
+ * @param ldb        leading dimension of B
+ * @param thread_pool
+ */
+void MLASCALL
+MlasNBitsGemmUnPackB(
+    float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool
+);
+
+/**
+ * @brief Get the workspace size required by computation.
+ *
+ * @param[in]    M           row size of matrix A and C
+ * @param[in]    N           column size of matrix B and C
+ * @param[in]    K           column size of matrix A and row size of matrix B
+ * @param[in]    BatchN      number of batches
+ * @param[inout] DataParams  an array (size BatchN) of parameter blocks
+ * @return                   workspace size in bytes
+ */
+size_t MLASCALL
+MlasSQNBitsGemmBatchWorkspaceSize(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+);
+
+/**
+ * @brief Batched GEMM: C = A * B
+ *        A and C must be float32 matrices;
+ *        B must be a packed nbits blob.
+ *
+ * @param[in]    M           row size of matrix A and C
+ * @param[in]    N           column size of matrix B and C
+ * @param[in]    K           column size of matrix A and row size of matrix B
+ * @param[in]    BatchN      number of batches
+ * @param[inout] DataParams  an array (size BatchN) of parameter blocks
+ * @param[in]    WorkSpace   temporary buffer
+ * @param[in]    ThreadPool
+ * @return
+ */
+void MLASCALL
+MlasSQNBitsGemmBatchPackedB(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+    void* WorkSpace,
+    MLAS_THREADPOOL* ThreadPool = nullptr
+);
diff --git a/onnxruntime/core/mlas/lib/activate.cpp b/onnxruntime/core/mlas/lib/activate.cpp
index 6c4ab8ae118dc..df3b884a7e7c9 100644
--- a/onnxruntime/core/mlas/lib/activate.cpp
+++ b/onnxruntime/core/mlas/lib/activate.cpp
@@ -143,6 +143,8 @@ struct MLAS_ACTIVATION_FUNCTION
         return MlasBlendFloat32x4(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value));
 #elif defined(MLAS_VSX_INTRINSICS)
         return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value));
+#elif defined(MLAS_LSX_INTRINSICS)
+        return MlasBlendFloat32x4(ValueTimesAlpha, Value, (__m128)__lsx_vfcmp_cle_s(ZeroFloat32x4, Value));
 #else
         return MlasBlendFloat32x4(ValueTimesAlpha, Value, ZeroFloat32x4 < Value);
 #endif
diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp
index 118351055157d..78cac2e617ff7 100644
--- a/onnxruntime/core/mlas/lib/compute.cpp
+++ b/onnxruntime/core/mlas/lib/compute.cpp
@@ -148,6 +148,9 @@ Return Value:
     //  instead.
     normal = _mm_min_epi16(normal, MaximumExponent);
     normal = _mm_max_epi16(normal, MinimumExponent);
+#elif defined(MLAS_LSX_INTRINSICS)
+    normal = __lsx_vmin_h(normal, MaximumExponent);
+    normal = __lsx_vmax_h(normal, MinimumExponent);
 #else
     normal = MlasMinimumInt32x4(normal, MaximumExponent);
     normal = MlasMaximumInt32x4(normal, MinimumExponent);
@@ -215,6 +218,8 @@ Return Value:
     //   N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle
     //   and use zeroes for the upper elements.
     Vector = _mm_load_ss(Input);
+#elif defined(MLAS_LSX_INTRINSICS)
+    Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0);
 #else
     Vector = MlasBroadcastFloat32x4(Input);
 #endif
@@ -467,6 +472,8 @@ Return Value:
     //   N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle and
     //   use zeroes for the upper elements.
     MLAS_FLOAT32X4 Vector = _mm_load_ss(Input);
+#elif defined(MLAS_LSX_INTRINSICS)
+    MLAS_FLOAT32X4 Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0);
 #else
     MLAS_FLOAT32X4 Vector = MlasBroadcastFloat32x4(Input);
 #endif
@@ -849,7 +856,7 @@ Return Value:
     //
     // Find the maximum value for the row.
     //
-#if defined(MLAS_TARGET_AMD64)
+#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
     float Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D);
 #else
     float Maximum = MlasReduceMaximumF32Kernel(Input, D);
@@ -874,7 +881,7 @@ Return Value:
         float Parameters[] = { NegativeMaximum, std::log(Accumulation)};
-#if defined(MLAS_TARGET_AMD64)
+#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
         GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters);
 #else
         MlasComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters);
@@ -899,7 +906,7 @@ Return Value:
         float Parameters[] = { 1.0f / Accumulation };
-#if defined(MLAS_TARGET_AMD64)
+#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
         GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters);
 #else
         MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters);
diff --git a/onnxruntime/core/mlas/lib/dgemm.cpp b/onnxruntime/core/mlas/lib/dgemm.cpp
index 1ef63d03c8014..50c62744f1d8e 100644
--- a/onnxruntime/core/mlas/lib/dgemm.cpp
+++ b/onnxruntime/core/mlas/lib/dgemm.cpp
@@ -530,7 +530,7 @@ Return Value:
     size_t RowsHandled;
 
-#if defined(MLAS_TARGET_AMD64_IX86) || defined (MLAS_TARGET_POWER)
+#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
     RowsHandled = GetMlasPlatform().GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
 #else
     if (ZeroMode) {
diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h
new file mode 100644
index 0000000000000..9cd1711a3ffd2
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_defs.h
@@ -0,0 +1,73 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+--*/
+
+#pragma once
+
+#include "jblas/jit_blas_prologue_b.h"
+#include "jblas/jit_blas_wrapper.h"
+
+namespace jblas
+{
+
+/*
+Name conversion explanation:
+Fp32: comp type, determined by the GemmCore; can be any jblas::gemm::SCorexxx (float GemmCore)
+S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4 (other integer and float weight
+classes are also supported)
+F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and
+jblas::epilogue::gemm::AccumulatorWriteBackFp32.
+
+Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores.
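+
+Example (a sketch): tLauncher_Fp32_S4_F32F32<tAVX512F> below is the fp32-compute
+launcher specialized for the AVX512F GemmCore, while the int8 variant pairs the
+same weight prologue with a quantizing activation prologue.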
+*/
+template <class GemmCore_T>
+using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
+    GemmCore_T::ISA,
+    GemmCore_T,
+    jblas::prologue_a::gemm::ActivationKBlockBaseF32,
+    jblas::prologue_b::gemm::WeightKBlockS4,
+    jblas::epilogue::gemm::CompFp32BlockEpilogue,
+    jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+
+/*
+Name conversion explanation:
+Int8: comp type, determined by the GemmCore; can be any jblas::gemm::ICorexxx (integer GemmCore)
+S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4 (integer weight classes only)
+F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and
+jblas::epilogue::gemm::AccumulatorWriteBackFp32.
+
+Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores.
+*/
+template <class GemmCore_T>
+using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
+    GemmCore_T::ISA,
+    GemmCore_T,
+    jblas::prologue_a::gemm::ActivationF32KBlockQuantize,
+    jblas::prologue_b::gemm::WeightKBlockS4,
+    jblas::epilogue::gemm::CompInt8BlockEpilogue,
+    jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+
+using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>;
+using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>;
+using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>;
+using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>;  // TODO(Yu) use 24x4 for higher efficiency
+using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>;
+using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>;
+using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>;
+using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>;  // TODO(Yu) use 24x4 for higher efficiency
+
+class ORTThreading : public jblas::parallel::IThreading
+{
+   public:
+    ORTThreading(void* tp);
+    void parallel_for(const jblas::parallel::thread_func& func) override;
+    void set_threads(int nthreads) override { assert(0); }
+    void sync() override { assert(0); }
+    void* mTp;
+};
+
+}  // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp
new file mode 100644
index 0000000000000..f3cae3186c28e
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_gemm.cpp
@@ -0,0 +1,534 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    jblas_gemm.cpp
+
+Abstract:
+
+    Currently only supports Q4 gemm.
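+
+    A sketch of the call flow from the MLAS API (caller-side names such as
+    "packed", "workspace" and "params" are illustrative; the entry points are
+    declared in mlas_qnbit.h):
+
+        size_t pack_size = MlasNBitsGemmPackBSize(N, K, block_size, 4, is_asym, CompInt8);
+        MlasNBitsGemmPackB(packed, QData, Scale, Zp, N, K, ldb, block_size,
+                           4, is_asym, true, CompInt8, tp);
+        size_t ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, 1, params);
+        MlasSQNBitsGemmBatchPackedB(M, N, K, 1, params, workspace, tp);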
+--*/
+
+#include "jblas_gemm.h"
+
+#include "jblas_defs.h"
+#include "mlasi.h"
+
+using namespace jblas;
+
+jblas::ORTThreading::ORTThreading(void* tp)
+    : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast<MLAS_THREADPOOL*>(tp))), mTp(tp)
+{
+}
+
+void
+jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func)
+{
+    MlasTrySimpleParallel(reinterpret_cast<MLAS_THREADPOOL*>(mTp), mThreadNum, [&](ptrdiff_t tid) {
+        func(static_cast<int>(tid));
+    });
+}
+
+template <class GemmCore_T>
+static void
+JblasSQ4GemmCompF32(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const float* A,
+    const size_t lda,
+    jblas::storage::gemm::StorageWeightKBlockS4* B,
+    float* C,
+    const size_t ldc,
+    int8_t* WorkSpace,
+    jblas::parallel::IThreading* th
+)
+{
+    auto M_ = static_cast<int>(M);
+    auto N_ = static_cast<int>(N);
+    auto K_ = static_cast<int>(K);
+    auto lda_ = static_cast<int>(lda);
+    auto ldc_ = static_cast<int>(ldc);
+    if (M <= 16) {
+        using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>;
+        using Launcher = tLauncher_Fp32_S4_F32F32<GemmCore_T>;
+        static Launcher kernel;
+        auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
+        if (B->mIsAsym) {
+            reduceA.assign(WorkSpace);
+            ORTThreading single(nullptr);
+            kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single);
+        }
+        typename Launcher::BEpiParam blkargs{
+            B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(),
+            reduceA.template get(), reduceA.lda};
+
+        typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}};
+        jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th);
+    } else {
+        using Parallel = jblas::parallel::gemm::SchedulerBase<GemmCore_T>;
+        using Launcher = jblas::wrapper::gemm::LauncherBase<
+            GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase,
+            jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+        static Launcher kernel;
+
+        typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}};
+        jblas::parallel::GemmBaseRun<Parallel>(kernel, args, th);
+    }
+}
+
+template <class GemmCore_T>
+static void
+JblasSQ4GemmCompInt8(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const float* A,
+    const size_t lda,
+    jblas::storage::gemm::StorageWeightKBlockS4* B,
+    float* C,
+    const size_t ldc,
+    int8_t* WorkSpace,
+    jblas::parallel::IThreading* th
+)
+{
+    using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>;
+    using Launcher = tLauncher_Int8_S4_F32F32<GemmCore_T>;
+    auto M_ = static_cast<int>(M);
+    auto N_ = static_cast<int>(N);
+    auto K_ = static_cast<int>(K);
+    auto lda_ = static_cast<int>(lda);
+    auto ldc_ = static_cast<int>(ldc);
+    static Launcher kernel;
+    auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym);
+    quanA.assign(WorkSpace);
+    if (M <= 16) {
+        ORTThreading single(nullptr);
+        kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single);
+    } else {
+        kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th);
+    }
+    typename Launcher::Param args{
+        M_,
+        N_,
+        K_,
+        B->mBlockSize,
+        {A, lda_, &quanA},
+        {B},
+        {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), quanA.mCStep,
+         quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(),
+         quanA.template RPtr(), B->mBlockSize},
+        {C, ldc_}};
+    jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th);
+}
+
+bool
+JblasSQ4GemmBatchDriver(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+    int8_t* WorkSpace,
+    MLAS_THREADPOOL* ThreadPool
+)
+{
+    GetCPUDevice();
+    ORTThreading orth(ThreadPool);
+    bool processed = true;
+    for (size_t i = 0; i < BatchN; i++) {
+        auto ptr =
jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto kptr = reinterpret_cast(ptr); + auto coretype = ptr->mCoreId; + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + JblasSQ4GemmCompF32( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + JblasSQ4GemmCompF32( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +template +static size_t +JblasSQ4GemmCompF32WorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc +) +{ + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(N); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using Launcher = tLauncher_Fp32_S4_F32F32; + static Launcher kernel; + if (B->mIsAsym) { + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + using Launcher = jblas::wrapper::gemm::LauncherBase< + GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, + jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + static Launcher kernel; + return 0; + } + return 0; +} + +template +static size_t +JblasSQ4GemmCompInt8WorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc +) +{ + using Parallel = jblas::parallel::gemm::SchedulerKBlock; + using Launcher = tLauncher_Int8_S4_F32F32; + static Launcher kernel; + (void)(N); + (void)(lda); + (void)(ldc); + auto quanA = kernel.mProA.createStorage( + static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->mIsAsym + ); + return quanA.mSize; +} + +size_t +JblasSQ4GemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const 
MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +) +{ + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto kptr = reinterpret_cast(ptr); + auto coretype = ptr->mCoreId; + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + size = std::max( + JblasSQ4GemmCompF32WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + size = std::max( + JblasSQ4GemmCompF32WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + } + } + } + return size; +} + +template +static size_t +JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) +{ + static T launcher; + auto stor = launcher.mProB.createStorage( + static_cast(N), static_cast(K), static_cast(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, + JBLAS_DTYPE::BF16, isAsym + ); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +size_t +JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) +{ + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case CompInt8: + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + case CompBf16: + case CompFp16: + case CompFp32: + case CompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + break; + default: + return 0; + } + return 0; +} + +template +static void 
+JblasQ4GemmPackBImpl( + void* PackedBuf, + size_t BlkSize, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + bool IsAsym, + bool lastCall, + size_t ldb, + MLAS_THREADPOOL* ThreadPool +) +{ + static T JblasKernel; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = JblasKernel.mProB.createStorage( + N_, K_, static_cast(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym + ); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + JblasKernel.mProB.reduceWeight(&stor, &orth); + } +} + +bool +JblasQ4GemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +) +{ + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case CompInt8: + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + case CompBf16: + case CompFp16: + case CompFp32: + case CompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + default: + return false; + } + return false; +} + +bool +JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) +{ + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == 
tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h new file mode 100644 index 0000000000000..044dc5e849a0a --- /dev/null +++ b/onnxruntime/core/mlas/lib/jblas_gemm.h @@ -0,0 +1,61 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + jblas_gemm.h + +Abstract: + + Currently only support Q4 gemm. +--*/ + +#pragma once + +#include "mlas_qnbit.h" + +size_t +JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); + +bool +JblasQ4GemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +); + +bool +JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb + , MLAS_THREADPOOL* ThreadPool); + +bool +JblasSQ4GemmBatchDriver( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + int8_t* WorkSpace, + MLAS_THREADPOOL* ThreadPool +); + +size_t +JblasSQ4GemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +); diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h new file mode 100644 index 0000000000000..8d812baabdf9d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h @@ -0,0 +1,27 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the double + precision matrix/matrix multiply operation (DGEMM). + +--*/ + +#define LFgemmElementShift 3 +#define LFgemmElementSize (1 << LFgemmElementShift) +#define LFgemmYmmElementCount (32/LFgemmElementSize) + +#include "FgemmKernelCommon.h" + +FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.d) +FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.d) +FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.d) +FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.d) diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S new file mode 100644 index 0000000000000..2f197d6891579 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S @@ -0,0 +1,32 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelLasx.s + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). 
+ + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "DgemmKernelCommon.h" +#include "FgemmKernelLasxCommon.h" + + .text + +// +// Generate the GEMM kernel. +// + +FgemmKernelLasxFunction MlasGemmDoubleKernelLasx + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S new file mode 100644 index 0000000000000..63395631a9bc5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S @@ -0,0 +1,217 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelLsx.s + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "FgemmKernelLsxCommon.h" + +FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.d) +/*++ + +Macro Description: + + This macro multiplies and accumulates for a 8xN block of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + a1 (rsi) - Supplies the address into the matrix B data. + + vr0-vr1 - Supplies up to two elements loaded from matrix A and matrix A + plus one row. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockSseBy8 RowCount + + vld $vr4, $a1, 0 + vld $vr5, $a1, 16 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.d $vr8, $vr4, $vr0, $vr8 + vfmadd.d $vr9, $vr5, $vr0, $vr9 +.if \RowCount\() == 2 + vfmadd.d $vr12, $vr6, $vr1, $vr12 + vfmadd.d $vr13, $vr7, $vr1, $vr13 +.endif + vld $vr4, $a1, 32 + vld $vr5, $a1, 48 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.d $vr10, $vr4, $vr0, $vr10 + vfmadd.d $vr11, $vr5, $vr0, $vr11 +.if \RowCount\() == 2 + vfmadd.d $vr14, $vr6, $vr1, $vr14 + vfmadd.d $vr15, $vr7, $vr1, $vr15 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t8 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t7 - Supplies the length in bytes of a row from matrix A. + + t5 - Supplies the length in bytes of a row from matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. 
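+
+    For example (a sketch): with RowCount 2, vr8-vr11 accumulate the first row
+    of the output and vr12-vr15 the second, matching ComputeBlockSseBy8 above.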
+ +--*/ + + .macro ProcessCountM RowCount, Fallthrough +.LProcessNextColumnLoop8xN\@: + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8,$vr8,$vr8" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9,$vr9,$vr9" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10,$vr10,$vr10" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11,$vr11,$vr11" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12,$vr12,$vr12" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13,$vr13,$vr13" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14,$vr14,$vr14" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15,$vr15,$vr15" + move $t7,$a3 # reload CountK +.LCompute8xNBlockBy1Loop\@: + EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a0, 0" + EmitIfCountGE \RowCount\(), 1, "vreplgr2vr.d $vr0, $s0" + EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a0, $t0" + EmitIfCountGE \RowCount\(), 2, "vreplgr2vr.d $vr1, $s0" + ComputeBlockSseBy8 \RowCount\() + addi.d $a1, $a1, 8*8 # advance matrix B by 8 columns + addi.d $a0, $a0, 8 # advance matrix A by 1 column + addi.d $t7, $t7, -1 + bnez $t7, .LCompute8xNBlockBy1Loop\@ + +.LOutput8xNBlock\@: + movfr2gr.d $s0, $f24 + vreplgr2vr.d $vr2, $s0 + # multiply by alpha + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr8, $vr8, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr9, $vr9, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr10,$vr10, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr11,$vr11, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr12,$vr12, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr13,$vr13, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr14,$vr14, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr15,$vr15, $vr2" + li.d $s0, 8 + blt $a5, $s0, .LOutputPartial8xNBlock\@ + sub.d $a5, $a5, $s0 + AccumulateAndStoreBlock \RowCount\(), 4 + addi.d $a2, $a2, 8*8 # advance matrix C by 8 columns + move $a0, $t1 # reload matrix A + bnez $a5, .LProcessNextColumnLoop8xN\@ + b .LExitKernel + +// +// Output a partial 8xN block to the matrix. +// + +.LOutputPartial8xNBlock\@: + li.d $s0, 2 + blt $a5, $s0, .LOutputPartial1xNBlock\@ + li.d $s0, 4 + blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ + li.d $s0, 6 + blt $a5, $s0, .LOutputPartialLessThan6xNBlock\@ + AccumulateAndStoreBlock \RowCount\(), 3 + andi $s0, $a5, 1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr11" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr15" + addi.d $a2, $a2, 6*8 # advance matrix C by 6 columns + b .LOutputPartial1xNBlock\@ + +.LOutputPartialLessThan6xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 2 + andi $s0, $a5,1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr10" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr14" + addi.d $a2, $a2, 4*8 # advance matrix C by 4 columns + b .LOutputPartial1xNBlock\@ + +.LOutputPartialLessThan4xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 1 + andi $s0, $a5,1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr9" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr13" + addi.d $a2, $a2, 2*8 # advance matrix C by 2 columns + +.LOutputPartial1xNBlock\@: + bnez $t5, .LSkipAccumulateOutput1xN\@ # ZeroMode? 
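+    # (annotation) When ZeroMode is set, the loads and adds below are skipped
+    # and the computed values are stored directly into matrix C.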
+
+    EmitIfCountGE \RowCount\(), 1, "fld.d $f15, $a2, 0"
+    EmitIfCountGE \RowCount\(), 1, "fadd.d $f15, $f15, $f8"
+    EmitIfCountGE \RowCount\(), 2, "fldx.d $f16, $a2, $t6"
+    EmitIfCountGE \RowCount\(), 2, "fadd.d $f16, $f16, $f12"
+
+.LSkipAccumulateOutput1xN\@:
+    EmitIfCountGE \RowCount\(), 1, "fst.d $f15, $a2, 0"
+    EmitIfCountGE \RowCount\(), 2, "fstx.d $f16, $a2, $t6"
+.ifb \Fallthrough\()
+    b .LExitKernel
+.endif
+
+    .endm
+
+//
+// Generate the GEMM kernel.
+//
+
+FgemmKernelLsxFunction MlasGemmDoubleKernelLSX
+
+    .end
diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h
new file mode 100644
index 0000000000000..777a592590ec4
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h
@@ -0,0 +1,100 @@
+/*++
+
+Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    FgemmKernelCommon.h
+
+Abstract:
+
+    This module contains common kernel macros and structures for the floating
+    point matrix/matrix multiply operation (SGEMM and DGEMM).
+
+--*/
+
+//
+// Define the typed instruction template.
+//
+
+#define FGEMM_TYPED_INSTRUCTION(Untyped, Typed) \
+    .macro Untyped Operand:vararg; Typed \Operand\(); .endm;
+
+/*++
+
+Macro Description:
+
+    This macro generates code to execute the block compute macro multiple
+    times and advance the matrix A and matrix B data pointers.
+
+Arguments:
+
+    ComputeBlock - Supplies the macro to compute a single block.
+
+    RowCount - Supplies the number of rows to process.
+
+    AdvanceMatrixAPlusRows - Supplies a non-zero value if the data pointer
+        in t7 should also be advanced as part of the loop.
+
+Implicit Arguments:
+
+    a0 - Supplies the address into the matrix A data.
+
+    t7 - Supplies the address into the matrix A data plus 3 rows.
+
+    a1 - Supplies the address into the matrix B data.
+
+    a3 - Supplies the number of columns from matrix A and the number of rows
+        from matrix B to iterate over.
+
+    vr4-vr15 - Supplies the block accumulators.
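+
+    For example (a sketch): ComputeBlockLasxLoop in FgemmKernelLasxCommon.h
+    invokes this macro with ComputeBlockLasxBy16 or ComputeBlockLasxBy8 as the
+    ComputeBlock argument.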
+
+--*/
+
+    .macro ComputeBlockLoop ComputeBlock, RowCount, AdvanceMatrixAPlusRows
+
+    move    $t8, $a3                    # reload CountK
+    li.d    $s0, 4
+    blt     $t8, $s0, .LProcessRemainingBlocks\@
+
+.LComputeBlockBy4Loop\@:
+    \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*0, 64*4
+    \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*1, 64*4
+    addi.d  $a1, $a1, 2*2*32            # advance matrix B by 128 bytes
+    \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*2, 64*4
+    \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*3, 64*4
+    addi.d  $a1, $a1, 2*2*32            # advance matrix B by 128 bytes
+    addi.d  $a0, $a0, 4*LFgemmElementSize # advance matrix A by 4 elements
+.if \RowCount\() > 3
+    addi.d  $t7, $t7, 4*LFgemmElementSize # advance matrix A plus rows by 4 elements
+.if \RowCount\() == 12
+    addi.d  $t3, $t3, 4*LFgemmElementSize
+    addi.d  $t4, $t4, 4*LFgemmElementSize
+.endif
+.endif
+    addi.d  $t8, $t8, -4
+    li.d    $s0, 4
+    bge     $t8, $s0, .LComputeBlockBy4Loop\@
+
+.LProcessRemainingBlocks\@:
+    beqz    $t8, .LOutputBlock\@
+
+.LComputeBlockBy1Loop\@:
+    \ComputeBlock\() \RowCount\(), 0, 0
+    addi.d  $a1, $a1, 2*32              # advance matrix B by 64 bytes
+    addi.d  $a0, $a0, LFgemmElementSize # advance matrix A by 1 element
+.if \RowCount\() > 3
+    addi.d  $t7, $t7, LFgemmElementSize # advance matrix A plus rows by 1 element
+.if \RowCount\() == 12
+    addi.d  $t3, $t3, LFgemmElementSize
+    addi.d  $t4, $t4, LFgemmElementSize
+.endif
+.endif
+    addi.d  $t8, $t8, -1
+    bnez    $t8, .LComputeBlockBy1Loop\@
+
+.LOutputBlock\@:
+
+    .endm
diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h
new file mode 100644
index 0000000000000..b96db848617bf
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h
@@ -0,0 +1,546 @@
+
+/*++
+
+Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    FgemmKernelLasxCommon.h
+
+Abstract:
+
+    This module implements the kernels for the floating point matrix/matrix
+    multiply operation (SGEMM and DGEMM).
+
+    This implementation uses LASX instructions.
+
+--*/
+
+/*++
+
+Macro Description:
+
+    This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output
+    matrix.
+
+Arguments:
+
+    RowCount - Supplies the number of rows to process.
+
+    VectorOffset - Supplies the byte offset from matrix B to fetch elements.
+
+    BroadcastOffset - Supplies the byte offset from matrix A to fetch elements.
+
+    PrefetchOffset - Optionally supplies the byte offset from matrix B to
+        prefetch elements.
+
+Implicit Arguments:
+
+    a0 - Supplies the address into the matrix A data.
+
+    t7 - Supplies the address into the matrix A data plus 2 rows.
+
+    a1 - Supplies the address into the matrix B data.
+
+    t0 - Supplies the length in bytes of a row from matrix A.
+
+    xr8-xr15 - Supplies the block accumulators.
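+
+    For example (a sketch): with RowCount 4 the accumulator pairs are xr8/xr9,
+    xr10/xr11, xr12/xr13 and xr14/xr15, one pair per row covering the two
+    YMMWORDs loaded from matrix B.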
+ +--*/ + + .macro ComputeBlockLasxBy16 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset + +.if \RowCount\() == 1 + xvldrepl.w $xr3, $a0, \BroadcastOffset\() + xvld $xr4, $a1, \VectorOffset\() + xvfmadd $xr8, $xr4, $xr3, $xr8 + xvld $xr5, $a1, \VectorOffset\()+32 + xvfmadd $xr9, $xr5, $xr3, $xr9 +.else + xvld $xr0, $a1, \VectorOffset\() + xvld $xr1, $a1, \VectorOffset\()+32 + EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3,$a0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr8, $xr3, $xr0, $xr8" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr1, $xr9" + EmitIfCountGE \RowCount\(), 2, "add.d $s0,$a0, $t0" + EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3,$s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr10, $xr3, $xr0, $xr10" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr1, $xr11" + + EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3,$t7, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr12, $xr3, $xr0, $xr12" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr1, $xr13" + EmitIfCountGE \RowCount\(), 4, "add.d $s0,$t7, $t0" + EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3,$s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr14, $xr3, $xr0, $xr14" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr1, $xr15" +.endif + + .endm + +/*++ + +Macro Description: + + This macro multiplies and accumulates for 1 YMMWORD by N rows of the output + matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + + PrefetchOffset - Optionally supplies the byte offset from matrix B to + prefetch elements. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + t7 - Supplies the address into the matrix A data plus 2 rows. + + a1 - Supplies the address into the matrix B data. + + t0 - Supplies the length in bytes of a row from matrix A. + + xr8-xr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLasxBy8 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset + +.if \RowCount\() == 1 + xvldrepl.w $xr3, $a0, \BroadcastOffset\() + xvld $xr5, $a1, \VectorOffset\() + xvfmadd.s $xr9, $xr5, $xr3, $xr9 +.else + xvld $xr0, $a1, \VectorOffset\() + EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3, $a0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr0, $xr9" + + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a0, $t0" + EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3, $s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr0, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3, $t7, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr0, $xr13" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t0" + EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3, $s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr0, $xr15" +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple + times and advancing the matrix A and matrix B data pointers. + +Arguments: + + ComputeBlock - Supplies the macro to compute a single block. + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + a1 - Supplies the address into the matrix B data. 
+ + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t0 - Supplies the length in bytes of a row from matrix A. + + vr4-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLasxLoop ComputeBlock, RowCount + +.if \RowCount\() > 2 + # compute matrix A plus 2 rows + slli.d $s0, $t0, 1 + add.d $t7, $a0, $s0 +.endif + ComputeBlockLoop \ComputeBlock\(), \RowCount\(), \RowCount\() > 2 +.if \RowCount\() > 2 + # compute matrix C plus 2 rows + slli.d $s0, $t6, 1 + add.d $t7, $a2, $s0 +.endif + + .endm + + .macro store_n src, num, dst + move $s2, \num\() + beqz $s2, .Lstore_exit\@ + xvstelm.w \src\(), \dst\(), 0, 0 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 4, 1 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 8, 2 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 12, 3 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 16, 4 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 20, 5 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 24, 6 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + +.Lstore_exit\@: + .endm +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t1 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t0 - Supplies the length in bytes of a row from matrix A. + + t6 - Supplies the length in bytes of a row from matrix C. + + t5 - Stores the ZeroMode argument from the stack frame. + +--*/ + + .macro ProcessCountM RowCount, Fallthrough + + ori $s1, $r0, LFgemmYmmElementCount + bgeu $s1, $a5, .LProcessRemainingCountN\@ + +.LProcessNextColumnLoop2xN\@: + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr8, $xr8, $xr8" + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr10, $xr10, $xr10" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr12, $xr12, $xr12" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr14, $xr14, $xr14" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" + + ComputeBlockLasxLoop ComputeBlockLasxBy16, \RowCount\() + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr8, $xr8, $xr2" + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr10, $xr10, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr12, $xr12, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr14, $xr14, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" + + sub.d $a5, $a5, $s1 + sub.d $a5, $a5, $s1 + blt $a5, $zero, .LOutputMasked2xNBlock\@ + andi $s0, $t5, 0xff # ZeroMode? 
+ bnez $s0, .LStore2xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0x20" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvld $xr16, $s0, 0x20" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0x20" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvld $xr16, $s0, 0x20" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" + +.LStore2xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0x20" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvst $xr11, $s0, 0x20" + EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0x20" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvst $xr15, $s0, 0x20" + + addi.d $a2, $a2, 0x40 # advance matrix C by 2 XRWORDs + move $a0, $t1 # reload matrix A + bltu $s1, $a5, .LProcessNextColumnLoop2xN\@ + beqz $a5, .LExitKernel + +.LProcessRemainingCountN\@: + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" + + + ComputeBlockLasxLoop ComputeBlockLasxBy8, \RowCount\() + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" + bltu $a5, $s1, .LOutputMasked1xNBlock\@ + andi $s0, $t5, 0xff # ZeroMode? + bnez $s0, .LStore1xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" + +.LStore1xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr11, $a2, $t6" + EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr15, $t7, $t6" + b .LExitKernel + +.LOutputMasked2xNBlock\@: + andi $s0, $t5, 0xff # ZeroMode? 
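+ # Reached when the columns remaining exceed one vector but not two: the
+ # first full XRWORD is stored below, CountN is then corrected for the
+ # double decrement above, and the sub-vector tail falls through to the
+ # masked 1xN path.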
+ bnez $s0, .LStoreMasked2xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" + +.LStoreMasked2xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" + EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" + addi.d $a2, $a2, 0x20 # advance matrix C by YMMWORD +.if \RowCount\() > 2 + addi.d $t7, $t7, 0x20 # advance matrix C plus 2 rows by YMMWORD + +.endif + addi.d $a5, $a5, LFgemmYmmElementCount # correct for over-subtract above + + +.LOutputMasked1xNBlock\@: + +.if \RowCount\() > 2 + slli.d $s0, $t0, 1 + add.d $t7, $a0, $s0 +.endif + +.if \RowCount\() == 1 +.else +.endif + +.if \RowCount\() > 2 + slli.d $s0, $t6, 1 + add.d $t7, $a2, $s0 +.endif + + sub.d $a5, $zero, $a5 + la.global $a0, MlasMaskMoveTableLasx + ori $s0, $r0, LFgemmElementSize + mul.d $s0, $a5, $s0 + addi.d $s0, $s0, 8*4 + xvldx $xr0, $a0, $s0 + andi $s0, $t5, 0xff + + sub.d $a5, $zero, $a5 + + bnez $s0, .LStoreMasked1xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvand.v $xr8, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvand.v $xr10, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvand.v $xr12, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvand.v $xr14, $xr16, $xr0" + + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr8" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr10" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr12" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr14" +.LStoreMasked1xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "store_n $xr9, $a5, $a2" + + add.d $s3, $a2, $t6 + EmitIfCountGE \RowCount\(), 2, "store_n $xr11, $a5, $s3" + + EmitIfCountGE \RowCount\(), 3, "store_n $xr13, $a5, $t7" + + add.d $s3, $t7, $t6 + EmitIfCountGE \RowCount\(), 4, "store_n $xr15, $a5, $s3" + sub.d $a5, $zero, $a5 +.ifb \Fallthrough\() + b .LExitKernel +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates the inner kernel to compute matrix multiplication. + +Arguments: + + FunctionName - Supplies the name for the generated function. + +--*/ + + .macro FgemmKernelLasxFunction FunctionName + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A a0 - Supplies the address of matrix A. + + B a1 - Supplies the address of matrix B. The matrix data has been packed + using MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C a2 - Supplies the address of matrix C. + + CountK a3 - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM a4 - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. 
+
+ CountN a5 - Supplies the number of columns from matrix B and matrix C to
+ iterate over.
+
+ lda a6 - Supplies the first dimension of matrix A.
+
+ ldc a7 - Supplies the first dimension of matrix C.
+
+ Alpha f0 - Supplies the scalar alpha multiplier (see GEMM definition).
+
+ ZeroMode (sp + 0) - Supplies true if the output matrix must be zero initialized,
+ else false if the output matrix is accumulated into.
+
+Return Value:
+
+ Returns the number of rows handled.
+
+--*/
+
+ FUNCTION_ENTRY \FunctionName\()
+
+ addi.d $sp, $sp, -64
+ st.d $ra, $sp, 56
+ st.d $s0, $sp, 0*8
+ st.d $s1, $sp, 1*8
+ fst.s $f0, $sp, 2*8 # spill Alpha for the vector broadcast below
+ fst.d $f16, $sp, 3*8
+ st.d $s2, $sp, 4*8
+ st.d $s3, $sp, 5*8
+
+ move $t1, $a0
+ slli.d $t0, $a6, 2 # convert lda to bytes
+ slli.d $t6, $a7, 2 # convert ldc to bytes
+ ld.d $t5, $sp, 64 # get ZeroMode
+ xvldrepl.w $xr2, $sp, 0x10 # broadcast Alpha
+
+//
+// Process 4 rows of the matrices.
+//
+
+ ori $s0, $zero, 4
+ bltu $a4, $s0, .LProcessCountMLessThan4
+ li.d $a4, 4 # return 4 rows handled
+ ProcessCountM 4, Fallthrough
+
+//
+// Restore non-volatile registers and return.
+//
+
+.LExitKernel:
+ bstrpick.d $a0, $a4, 31, 0
+ ld.d $s0, $sp, 0
+ ld.d $s1, $sp, 8
+ fld.d $f16, $sp, 3*8
+ ld.d $s2, $sp, 4*8
+ ld.d $s3, $sp, 5*8
+ ld.d $ra, $sp, 7*8
+ addi.d $sp, $sp, 64
+ jr $ra
+
+//
+// Process 2 rows of the matrices.
+//
+
+.LProcessCountMLessThan4:
+ ori $s0, $r0, 2
+ bltu $a4, $s0, .LProcessCountMLessThan2
+ li.d $a4, 2 # return 2 rows handled
+ ProcessCountM 2
+
+//
+// Process 1 row of the matrices.
+//
+
+.LProcessCountMLessThan2:
+ ProcessCountM 1
+
+ .endm
diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h
new file mode 100644
index 0000000000000..0333af792ba70
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h
@@ -0,0 +1,170 @@
+/*++
+
+Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ FgemmKernelLsxCommon.h
+
+Abstract:
+
+ This module implements the kernels for the floating point matrix/matrix
+ multiply operation (SGEMM and DGEMM).
+
+ This implementation uses Lsx instructions.
+
+--*/
+
+#include "FgemmKernelCommon.h"
+/*++
+
+Macro Description:
+
+ This macro stores the block accumulators to the output matrix with an
+ optional accumulation of the existing contents of the output matrix.
+
+Arguments:
+
+ RowCount - Supplies the number of rows to process.
+
+ VectorCount - Supplies the number of vector columns to process.
+
+Implicit Arguments:
+
+ t6 - Supplies the length in bytes of a row from matrix C.
+
+ a2 - Supplies the address of matrix C.
+
+ t5 - Stores the ZeroMode argument from the stack frame.
+
+ vr8-vr15 - Supplies the block accumulators.
+
+--*/
+
+ .macro AccumulateAndStoreBlock RowCount, VectorCount
+
+ and $s0, $t5, $t5 # ZeroMode?
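+ # Nonzero ZeroMode stores the accumulators directly; otherwise the existing
+ # contents of matrix C are loaded and accumulated first.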
+ bnez $s0 , .LSkipAccumulateOutput\@ + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vld $vr0, $a2, 0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vld $vr1, $a2, 16" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vld $vr2, $a2, 32" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vld $vr3, $a2, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vldx $vr4, $a2, $t6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vldx $vr5, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vldx $vr6, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vldx $vr7, $a2, $s0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vfadd $vr8, $vr8, $vr0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vfadd $vr9, $vr9, $vr1" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vfadd $vr10,$vr10,$vr2" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vfadd $vr11,$vr11,$vr3" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vfadd $vr12,$vr12,$vr4" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vfadd $vr13,$vr13,$vr5" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vfadd $vr14,$vr14,$vr6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vfadd $vr15,$vr15,$vr7" + +.LSkipAccumulateOutput\@: + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vst $vr8, $a2, 0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vst $vr9, $a2, 16" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vst $vr10, $a2, 32" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vst $vr11, $a2, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vstx $vr12, $a2, $t6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vstx $vr13, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vstx $vr14, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vstx $vr15, $a2, $s0" + + .endm +/*++ + +Macro Description: + + This macro generates the inner kernel to compute matrix multiplication. + +Arguments: + + FunctionName - Supplies the name for the generated function. + +--*/ + + .macro FgemmKernelLsxFunction FunctionName + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (a0) - Supplies the address of matrix A. + + B (a1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C (a2) - Supplies the address of matrix C. + + CountK (a3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (a4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (a5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (a6) Supplies the first dimension of matrix A. + + ldc (a7) Supplies the first dimension of matrix C. 
+
+ Alpha (f0) - Supplies the scalar alpha multiplier (see GEMM definition).
+
+ ZeroMode (sp + 0) - Supplies true if the output matrix must be zero initialized,
+ else false if the output matrix is accumulated into.
+
+Return Value:
+
+ Returns the number of rows handled.
+
+--*/
+
+FUNCTION_ENTRY \FunctionName\()
+ addi.d $sp, $sp, -64
+ st.d $t5, $sp, 0
+ st.d $s0, $sp, 1*8
+ st.d $s1, $sp, 2*8
+ st.d $s2, $sp, 3*8
+ st.d $s3, $sp, 4*8
+ move $t1, $a0
+ slli.d $t0, $a6, 2 // convert lda to bytes
+ slli.d $t6, $a7, 2 // convert ldc to bytes
+ ld.d $t5, $sp, 64 // get ZeroMode
+ fmov.s $f24, $f0 // f0 is clobbered by the Lsx code below; preserve Alpha in f24
+
+ li.d $s0, 2
+ blt $a4, $s0, .LProcessCountM1
+
+ li.d $a4, 2
+ ProcessCountM 2, Fallthrough
+
+.LExitKernel:
+ ld.d $t5, $sp, 0
+ ld.d $s0, $sp, 1*8
+ ld.d $s1, $sp, 2*8
+ ld.d $s2, $sp, 3*8
+ ld.d $s3, $sp, 4*8
+ addi.d $sp, $sp, 64
+ move $a0, $a4
+ jr $ra
+
+.LProcessCountM1:
+ ProcessCountM 1
+ .endm
diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S
new file mode 100644
index 0000000000000..e03503521912a
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S
@@ -0,0 +1,412 @@
+/*++
+
+Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ SconvKernelLasx.S
+
+Abstract:
+
+ This module implements the kernels for the single precision convolution
+ operation.
+
+ This implementation uses Lasx instructions.
+
+--*/
+
+#include "asmmacro.h"
+#include "SconvKernelLasxCommon.h"
+
+ .text
+
+/*++
+
+Macro Description:
+
+ This macro multiplies and accumulates for a FilterCount by OutputCount block
+ of the output buffer.
+
+Arguments:
+
+ KernelType - Supplies the type of kernel to be generated.
+
+ FilterCount - Supplies the number of rows from the filter to process.
+
+ OutputCount - Supplies the number of output blocks to produce.
+
+ VectorOffset - Supplies the byte offset from the filter buffer to fetch
+ elements.
+
+ BroadcastOffset - Supplies the byte offset from the input buffer to fetch
+ elements.
+
+Implicit Arguments:
+
+ a3 - Supplies the address of the input buffer.
+
+ a2 - Supplies the address of the filter buffer.
+
+ a1 - Supplies the FilterStride parameter (see function description).
+
+ t7 - Supplies the address of the filter buffer plus 2 * FilterStride.
+
+ a5 - Supplies the StrideWidth parameter (see function description).
+
+ xr0-xr7 - Supplies the block accumulators.
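+
+ Reference model (reviewer note): a minimal C sketch of the non-depthwise
+ multiply/accumulate this macro emits for one kernel tap, assuming the
+ 8-element nchw8c blocking used by this module; the helper name, the
+ element-unit strides, and the acc[f][o] layout notation are illustrative
+ only:
+
+ void ComputeBlockModel(float acc[4][2][8], const float* input,
+ const float* filter, size_t filterStrideElems,
+ size_t strideWidthElems, unsigned filterCount,
+ unsigned outputCount)
+ {
+ for (unsigned f = 0; f < filterCount; f++) {
+ for (unsigned o = 0; o < outputCount; o++) {
+ float b = input[o * strideWidthElems]; /* broadcast element */
+ for (unsigned i = 0; i < 8; i++) {
+ acc[f][o][i] += filter[f * filterStrideElems + i] * b;
+ }
+ }
+ }
+ }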
+ +--*/ + + .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset + +.ifeqs "\KernelType\()","Depthwise" + xvld $xr12, $a2, 0 + EmitIfCountGE \OutputCount\(), 1, "xvld $xr8, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr12, $xr0" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr9, $a3, $a5" + EmitIfCountGE \OutputCount\(), 2, "xvfmadd.s $xr4, $xr9, $xr12, $xr4" + +.else + EmitIfCountGE \OutputCount\(), 1, "xvldrepl.w $xr13, $a3, \BroadcastOffset\()" + EmitIfCountGE \OutputCount\(), 2, "add.d $s0, $a3, $a5" + EmitIfCountGE \OutputCount\(), 2, "xvldrepl.w $xr14, $s0, \BroadcastOffset\()" +.if \OutputCount\() == 1 + EmitIfCountGE \FilterCount\(), 1, "xvld $xr8, $a2, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr13, $xr0" + EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr9, $s0, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 2, "xvfmadd.s $xr1, $xr9, $xr13, $xr1" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr10, $t7, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 3, "xvfmadd.s $xr2, $xr10, $xr13, $xr2" + EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr11, $s0, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 4, "xvfmadd.s $xr3, $xr11, $xr13, $xr3" +.else + EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a2, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmadd.s $xr0, $xr12, $xr13, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmadd.s $xr4, $xr12, $xr14, $xr4" + EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr12, $s0, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmadd.s $xr1, $xr13, $xr12, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmadd.s $xr5, $xr14, $xr12, $xr5" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr12, $t7, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmadd.s $xr2, $xr13, $xr12, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmadd.s $xr6, $xr14, $xr12, $xr6" + EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr12, $s0, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmadd.s $xr3, $xr13, $xr12, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmadd.s $xr7, $xr14, $xr12, $xr7" +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + t7 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t5 - Supplies the InputStride parameter (see function description). 
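+
+ Processing order (reviewer note): the generated code first walks the
+ OutputCountLeftPad outputs one at a time through the padding-aware
+ single-output helper, then the interior OutputCount outputs two at a
+ time, and finally folds any remainder into the OutputCountRightPad
+ outputs, which are handled by the single-output helper again.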
+ +--*/ + + .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount + +// +// Process the output blocks that include left padding. +// + + ld.d $t0, $sp, OutputCountLeftPad_arg + beqz $t0, .L\KernelType\().\FilterCount\().ProcessOutputCount + bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() + +// +// Process the output blocks that do not include any padding. +// + +.L\KernelType\().\FilterCount\().ProcessOutputCount: + ld.d $t0, $sp, OutputCount_arg + li.d $s0, 2 + bltu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessRemainingOutputCount + +.L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2: + ProcessOutputCountN Lasx, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 2 + slli.d $s0, $a5, 1 # advance input by 2 elements + add.d $a0, $a0, $s0 + addi.d $t0, $t0, -2 + li.d $s0, 2 + bgeu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2 + +.L\KernelType\().\FilterCount\().ProcessRemainingOutputCount: + +// +// Process the output blocks that include right padding plus any remaining output +// blocks from above. +// + +.L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining: + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\KernelType\().ExitKernel + bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows for a pointwise convolution. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t0 - Supplies the OutputCount parameter (see function description). + + t2 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseFilterCountN FilterCount + li.d $s0, 2 + bltu $t0, $s0, .LPointwise.\FilterCount\().ProcessRemainingOutputCount + +.LPointwise.\FilterCount\().ProcessNextOutputCountBy2: + ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 2 + slli.d $s0, $a5, 1 # advance input by 2 elements + add.d $a0, $a0, $s0 + addi.d $t0, $t0, -2 + li.d $s0, 2 + bgeu $t0, $s0, .LPointwise.\FilterCount\().ProcessNextOutputCountBy2 + +.LPointwise.\FilterCount\().ProcessRemainingOutputCount: + beqz $t0, .LPointwise.ExitKernel + ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 1 + + .endm + +// +// Generate the convolution kernels. +// + + SconvKernelFunction Nchw, 8, Lasx + SconvKernelFunction Nchwc, 8, Lasx, BiasFilter + SconvKernelDepthwiseFunction 8, Lasx + SconvKernelPointwiseFunction Lasx, BiasFilter + +/*++ + +Macro Description: + + This macro generates code to process an output block after the inner + convolution kernel has executed and then stores the output block to the + output buffer. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
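+
+ Reference model (reviewer note): a C sketch of the post-processing applied
+ to each 8-wide accumulator vector, using the flag bits defined in
+ SconvKernelLasxCommon.h; the helper name is illustrative only:
+
+ void PostProcessModel(float* output, const float acc[8], const float* bias,
+ unsigned flags)
+ {
+ for (int i = 0; i < 8; i++) {
+ float v = acc[i];
+ if (flags & 0x1) v += output[i]; /* ACCUMULATE_OUTPUT */
+ if (flags & 0x2) v += bias[i]; /* BIAS_ADDITION */
+ if (flags & 0x4) v = v > 0.0f ? v : 0.0f; /* RELU_ACTIVATION */
+ output[i] = v;
+ }
+ }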
+ +--*/ + + .macro PostProcessBlock FilterCount, OutputCount + + .globl MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() + .hidden MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() +MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\(): + + .globl MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() + .hidden MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() +MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\(): + +.if \FilterCount\() > 2 + slli.d $s0, $t6, 1 # compute output plus 2 rows + add.d $t7, $a4, $s0 +.endif + +// +// Test if the existing contents of the output buffer should be accumulated +// with the output block. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvld $xr16, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvld $xr16, $a4, 32" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvld $xr16, $a4, 0x40" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvldx $xr16, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvld $xr16, $s0, 0x20" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvld $xr16, $s0, 0x40" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr16" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvld $xr16,$t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvld $xr16,$t7, 0x20" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvld $xr16,$t7, 0x40" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvldx $xr16,$t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvld $xr16,$s0, 0x20" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvld $xr16,$s0, 0x40" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr16" + + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: + +// +// Test if the bias buffer should be accumulated with the output block. 
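+// The bias buffer holds one 8-element vector per filter row, 0x20 bytes
+// apart, and the same vector is added to every output column of that row.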
+// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition +.if \OutputCount\() == 1 + EmitIfCountGE \FilterCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \FilterCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr16, $a3, 0x20" + EmitIfCountGE \FilterCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr16, $a3, 0x40" + EmitIfCountGE \FilterCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr16, $a3, 0x60" + EmitIfCountGE \FilterCount\(), 4, "xvfadd.s $xr3, $xr3, $xr16" +.else + EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a3, 0" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr13, $a3, 0x20" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr14, $a3, 0x40" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr15, $a3, 0x60" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr12" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr12" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr12" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr13" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr13" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr13" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr14" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr14" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr15" + +.endif + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: + +// +// Test for fused ReLU activation. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation + xvxor.v $xr15, $xr15, $xr15 + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmax.s $xr0, $xr15, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmax.s $xr4, $xr15, $xr4" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfmax.s $xr8, $xr15, $xr8" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmax.s $xr1, $xr15, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmax.s $xr5, $xr15, $xr5" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfmax.s $xr9, $xr15, $xr9" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmax.s $xr2, $xr15, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmax.s $xr6, $xr15, $xr6" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfmax.s $xr10, $xr15, $xr10" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmax.s $xr3, $xr15, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmax.s $xr7, $xr15, $xr7" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfmax.s $xr11, $xr15, $xr11" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: + +// +// Store the output block in the output buffer. 
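+// Row f of the block is stored at Output + f * OutputStride; within a row,
+// output column o lands at byte offset o * 0x20 (one nchw8c block each).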
+// + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvst $xr0, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvst $xr4, $a4, 0x20" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvst $xr8, $a4, 0x40" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvstx $xr1, $a4, $t6" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvst $xr5, $s0, 0x20" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvst $xr9, $s0, 0x40" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvst $xr2, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvst $xr6, $t7, 0x20" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvst $xr10, $t7, 0x40" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvstx $xr3, $t7, $t6" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvst $xr7, $s0, 0x20" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvst $xr11, $s0, 0x40" + + add_immed $a4,\OutputCount\()*8*4 # advance output by N nchw8c blocks + jr $ra + + .endm + + .irp FilterCount, 1, 2, 3, 4 + .irp OutputCount, 1, 2, 3 + PostProcessBlock \FilterCount\(), \OutputCount\() + .endr + .endr + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h new file mode 100644 index 0000000000000..bd2db816ed9ab --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h @@ -0,0 +1,868 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLasxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision convolution operation for the Lasx kernels. + +--*/ + + +#define SP_SIZE 32*8 + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#define OutputStride_arg 6*8 +#define KernelHeight_arg 7*8 +#define KernelWidth_arg 8*8 +#define InputBase_arg 9*8 +#define InputWidth_arg 10*8 +#define DilatedInputWidth_arg 11*8 +#define OutputCountLeftPad_arg 12*8 +#define OutputCount_arg 13*8 +#define OutputCountRightPad_arg 14*8 +#define Bias_arg 15*8 +#define Flags_arg 16*8 +#define InputChannels_arg 17*8 +#define Filter_save_offset 18*8 + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
+ +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + s8 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t5 - Supplies the InputStride parameter (see function description). +--*/ + .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount + + move $a3, $a0 +.ifeqs "\KernelType\()","Depthwise" + move $a2, $a1 +.else + ld.d $a2, $sp, Filter_save_offset +.endif + ld.d $t1, $sp, KernelHeight_arg + ld.d $t2, $sp, KernelWidth_arg +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $zero, $t3 +.endif + ClearBlock \FilterCount\(), \OutputCount\() + beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: + move $t6, $t2 # reload kernel width remaining + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 # compute (Input - InputBase) + # (Input - InputBase) >= InputWidth? + bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding +.endif +.if \OutputCount\() > 3 + slli.d $s0, $a5, 1 + add.d $s0, $s0, $a5 + add.d $t4, $a3, $s0 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + slli.d $s0, $a1, 1 # compute filter plus 2 rows + add.d $t7, $a2, $s0 +.endif +.ifeqs "\KernelType\()","Nchwc" +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif +.else + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 +.endif + +.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: + # advance input by dilation width + add.d $a3, $a3, $t8 +.ifeqs "\KernelType\()","Nchwc" + # advance filter by 8i8o/16i16o block + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 +.else + addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block +.endif + addi.d $t6, $t6, -1 + bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t5 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + # advance input base to next row + sub.d $t3, $t3, $s0 +.endif + addi.d $t1, $t1, -1 # decrement rows remaining + bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow + +// +// Handle post processing of the output block. +// + +.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: + ld.w $a2, $sp, Flags_arg +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. 
+
+ BiasFilter - Supplies a non-blank value if the address of the filter buffer
+ should be biased to point to the middle of an OIhw8i8o block in order to
+ reduce the code size from relative byte offsets.
+
+--*/
+
+ .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter
+
+/*++
+
+Routine Description:
+
+ This routine is the inner kernel to compute a convolution for the elements
+ of an output row for a set of filter rows.
+
+Arguments:
+
+ Input (a0) - Supplies the address of the input buffer.
+
+ The address is biased to include padding blocks for the left width
+ dimension. The address is not biased to include padding rows for the
+ left height dimension; these are accounted for in the outer kernel.
+
+ Filter (a1) - Supplies the address of the filter buffer.
+
+ Output (a2) - Supplies the address of the output buffer.
+
+ StrideWidth (a3) - Supplies the length in bytes of the blocked stride width.
+
+ DilationWidth (a4) - Supplies the length in bytes of the blocked dilation
+ width.
+
+ FilterCount (a5) - Supplies the number of filters to process in this
+ iteration.
+
+ InputStride (a6) - Supplies the length in bytes to advance the input buffer to
+ the next input row.
+
+ FilterStride (a7) - Supplies the length in bytes to advance the filter buffer
+ to the next set of filters.
+
+ OutputStride (sp + 0) - Supplies the length in bytes to advance the output buffer
+ to the next output address associated with the next set of filters.
+
+ KernelHeight (sp + 8) - Supplies the height of the kernel to apply. This height may
+ be less than the original kernel height after removing any padding
+ rows.
+
+ KernelWidth (sp + 0x10) - Supplies the width of the kernel to apply.
+
+ InputBase (sp + 0x18) - Supplies the address of the valid input buffer.
+
+ This parameter is similar to the Input parameter, but does not include
+ the padding blocks for the left width dimension. This parameter is used
+ with the following InputWidth parameter in order to validate that the
+ current input buffer address is in bounds and not in the left or right
+ width padding region.
+
+ InputWidth (sp + 0x20) - Supplies the length in bytes of the blocked input width.
+
+ DilatedInputWidth (sp + 0x28) - Supplies the length in bytes to advance the input base
+ buffer to the next input row including dilation.
+
+ OutputCountLeftPad (sp + 0x30) - Supplies the number of output elements that include
+ one or more padding elements from the left edge.
+
+ OutputCount (sp + 0x38) - Supplies the number of output elements that do not include
+ any padding elements.
+
+ OutputCountRightPad (sp + 0x40) - Supplies the number of output elements that include
+ one or more padding elements from the right edge.
+
+ Bias (sp + 0x48) - Supplies the address of the bias buffer.
+
+ Flags (sp + 0x50) - Supplies additional flags controlling the convolution operation,
+ especially post-calculation options.
+
+Return Value:
+
+ None.
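+
+ Reference signature (reviewer note): the register and stack layout above
+ corresponds to a C prototype of roughly this shape; this sketch is for
+ review only and is not a declaration used by the build:
+
+ void MlasConvNchwcFloatKernelLasx(
+ const float* Input, const float* Filter, float* Output,
+ size_t StrideWidth, size_t DilationWidth, size_t FilterCount,
+ size_t InputStride, size_t FilterStride, size_t OutputStride,
+ size_t KernelHeight, size_t KernelWidth,
+ const float* InputBase, size_t InputWidth, size_t DilatedInputWidth,
+ size_t OutputCountLeftPad, size_t OutputCount,
+ size_t OutputCountRightPad, const float* Bias, unsigned Flags);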
+ +--*/ + + FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, OutputStride_arg + st.d $t1, $sp, KernelHeight_arg + st.d $t2, $sp, KernelWidth_arg + st.d $t3, $sp, InputBase_arg + ld.d $t0, $sp, SP_SIZE+4*8 + ld.d $t1, $sp, SP_SIZE+5*8 + ld.d $t2, $sp, SP_SIZE+6*8 + ld.d $t3, $sp, SP_SIZE+7*8 + st.d $t0, $sp, InputWidth_arg + st.d $t1, $sp, DilatedInputWidth_arg + st.d $t2, $sp, OutputCountLeftPad_arg + st.d $t3, $sp, OutputCount_arg + ld.d $t0, $sp, SP_SIZE+8*8 + ld.d $t1, $sp, SP_SIZE+9*8 + ld.d $t2, $sp, SP_SIZE+10*8 + st.d $t0, $sp, OutputCountRightPad_arg + st.d $t1, $sp, Bias_arg + st.d $t2, $sp, Flags_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $a1, $a1, 4*8*4 +.endif + st.d $a1, $sp, Filter_save_offset + move $a1, $a7 + move $t5, $a6 + move $t8, $a4 + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ori $s0, $zero, 3 + beq $t1, $s0, .L\KernelType\().ProcessFilterCount3 + bltu $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3 + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 4 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount3: + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 3 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCountLessThan3: + ori $s0, $zero, 2 + bltu $t1, $s0, .L\KernelType\().ProcessFilterCount1 + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 2 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount1: + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 1 + +// +// Restore non-volatile registers and return. +// + +.L\KernelType\().ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jirl $zero, $ra, 0 + +.ifnes "\Isa\()","LSX" + +// +// Generate out-of-band helpers for handling output blocks involving padding. 
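+// These helpers run the OutputCount == 1 path, which re-checks the input
+// address against [InputBase, InputBase + InputWidth) for every kernel tap
+// so that taps falling in the width padding region are skipped.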
+//
+
+ .irp FilterCount, 1, 2, 3, 4
+
+MlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\():
+ st.d $ra, $sp, 19*8
+loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\():
+ ProcessOutputCountN \Isa\(), LSconvKernelSingleFrame, \KernelType\(), \BlockSize\(), \FilterCount\(), 1
+ add.d $a0, $a0, $a5 # advance input by 1 element
+ addi.d $t0, $t0, -1 # decrement output count remaining
+ bnez $t0, loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\()
+ ld.d $ra, $sp, 19*8
+ jr $ra
+
+ .endr
+
+.endif
+
+ .endm
+
+/*++
+
+Macro Description:
+
+ This macro generates code for the inner convolution kernel for the special
+ case of a depthwise separable convolution.
+
+Arguments:
+
+ BlockSize - Supplies the number of elements per block.
+
+ Isa - Supplies the instruction set architecture string for function tags.
+
+--*/
+
+ .macro SconvKernelDepthwiseFunction BlockSize, Isa
+
+/*++
+
+Routine Description:
+
+ This routine is the inner kernel to compute a convolution for the elements
+ of an output row for a set of filter rows.
+
+ Depthwise separable convolutions are a form of grouped convolution where
+ the number of input and output channels per group is one.
+
+Arguments:
+
+ Input (a0) - Supplies the address of the input buffer.
+
+ The address is biased to include padding blocks for the left width
+ dimension. The address is not biased to include padding rows for the
+ left height dimension; these are accounted for in the outer kernel.
+
+ Filter (a1) - Supplies the address of the filter buffer.
+
+ Output (a2) - Supplies the address of the output buffer.
+
+ StrideWidth (a3) - Supplies the length in bytes of the blocked stride width.
+
+ DilationWidth (a4) - Supplies the length in bytes of the blocked dilation
+ width.
+
+ InputStride (a5) - Supplies the length in bytes to advance the input buffer
+ to the next input row.
+
+ KernelHeight (a6) - Supplies the height of the kernel to apply. This height may
+ be less than the original kernel height after removing any padding
+ rows.
+
+ KernelWidth (a7) - Supplies the width of the kernel to apply.
+
+ InputBase (sp + 0) - Supplies the address of the valid input buffer.
+
+ This parameter is similar to the Input parameter, but does not include
+ the padding blocks for the left width dimension. This parameter is used
+ with the following InputWidth parameter in order to validate that the
+ current input buffer address is in bounds and not in the left or right
+ width padding region.
+
+ InputWidth (sp + 8) - Supplies the length in bytes of the blocked input width.
+
+ DilatedInputWidth (sp + 0x10) - Supplies the length in bytes to advance the input base
+ buffer to the next input row including dilation.
+
+ OutputCountLeftPad (sp + 0x18) - Supplies the number of output elements that include
+ one or more padding elements from the left edge.
+
+ OutputCount (sp + 0x20) - Supplies the number of output elements that do not include
+ any padding elements.
+
+ OutputCountRightPad (sp + 0x28) - Supplies the number of output elements that include
+ one or more padding elements from the right edge.
+
+ Bias (sp + 0x30) - Supplies the address of the bias buffer.
+
+ Flags (sp + 0x38) - Supplies additional flags controlling the convolution operation,
+ especially post-calculation options.
+
+Return Value:
+
+ None.
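+
+ Reference model (reviewer note): the depthwise kernel multiplies input and
+ filter elementwise per channel instead of broadcasting; a C sketch for one
+ 8-channel block and one kernel tap, with an illustrative helper name:
+
+ void DepthwiseTapModel(float acc[8], const float* input,
+ const float* filter)
+ {
+ for (int c = 0; c < 8; c++) {
+ acc[c] += input[c] * filter[c];
+ }
+ }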
+ +--*/ + + FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + st.d $a6, $sp, KernelHeight_arg + st.d $a7, $sp, KernelWidth_arg + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, InputBase_arg + st.d $t1, $sp, InputWidth_arg + st.d $t2, $sp, DilatedInputWidth_arg + st.d $t3, $sp, OutputCountLeftPad_arg + ld.d $t0, $sp, SP_SIZE+4*8 + ld.d $t1, $sp, SP_SIZE+5*8 + ld.d $t2, $sp, SP_SIZE+6*8 + ld.d $t3, $sp, SP_SIZE+7*8 + st.d $t0, $sp, OutputCount_arg + st.d $t1, $sp, OutputCountRightPad_arg + st.d $t2, $sp, Bias_arg + st.d $t3, $sp, Flags_arg + + move $t8, $a4 + move $t5, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ProcessFilterCountN LSconvKernelDepthwiseFrame, Depthwise, 1 + +// +// Restore non-volatile registers and return. +// + +.LDepthwise.ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + +.ifnes "\Isa\()","LSX" + +// +// Generate out-of-band helpers for handling output blocks involving padding. +// + +MlasConvDepthwiseFloatSingle\Isa\()Filter1: + st.d $ra, $sp, 20*8 +MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop: + ProcessOutputCountN \Isa\(), LSconvKernelDepthwiseSingleFrame, Depthwise, \BlockSize\(), 1, 1 + add.d $a0, $a0, $a5 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + + bnez $t0, MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop + ld.d $ra, $sp, 20*8 + jr $ra + +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks + for a pointwise convolution. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). 
+ + t2 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount + + move $a3, $a0 + move $a2, $t2 + ld.d $t1, $sp, InputChannels_arg + ClearBlock \FilterCount\(), \OutputCount\() + +.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: +.if \OutputCount\() > 3 + slli.d $s0, $a5, 1 + add.d $s0, $s0, $a5 + add.d $t4, $s0, $a3 +.endif +.if \FilterCount\() > 2 + slli.d $s0, $a1, 1 + add.d $t7, $a2, $s0 +.endif +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif + add.d $a3, $a3, $t8 # advance input to next channel block + + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 # advance filter by 8i8o/16i16o block + addi.d $t1, $t1, -1 # decrement input blocks remaining + + bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock + +// +// Handle post processing of the output block. +// + + ld.w $a2, $sp, Flags_arg +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel for the special + case where the kernel dimensions are 1. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelPointwiseFunction Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Pointwise convolutions have a kernel size of one. To simplify this + implementation, no input padding is allowed, which matches typical usage in + models. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + InputChannels (a4) - Supplies the number of input channels to process. + + FilterCount (a5) - Supplies the number of rows from the filter to process. + + InputStride (a6) - Supplies the length in bytes to advance the input buffer to + the next input channel of the same input row. + + FilterStride (a7) - Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp + 0)- Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + OutputCount (sp + 8)- Supplies the number of output elements. + + Bias (sp + 0x10)- Supplies the address of the bias buffer. + + Flags (sp + 0x18)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
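+
+ Reference model (reviewer note): with a 1x1 kernel the convolution reduces
+ to a matrix multiply over the input channel blocks; a C sketch of the
+ accumulation for one output position, assuming contiguous 8-element
+ channel blocks in 8i8o filter order and ignoring the BiasFilter pointer
+ bias used by the assembly (the helper name is illustrative only):
+
+ void PointwiseModel(float acc[8], const float* input, const float* filter,
+ size_t inputChannelBlocks)
+ {
+ for (size_t b = 0; b < inputChannelBlocks; b++) {
+ for (int i = 0; i < 8; i++) { /* input channel within block */
+ for (int o = 0; o < 8; o++) { /* output channel within block */
+ acc[o] += input[b * 8 + i] * filter[(b * 8 + i) * 8 + o];
+ }
+ }
+ }
+ }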
+ +--*/ + + FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, OutputStride_arg + st.d $t1, $sp, OutputCount_arg + st.d $t2, $sp, Bias_arg + st.d $t3, $sp, Flags_arg + st.d $a4, $sp, InputChannels_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $t2, $a1, 4*8*4 +.else + move $t2, $a1 +.endif + ld.d $t0, $sp, OutputCount_arg + move $a1, $a7 + move $t8, $a6 + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ori $s0, $zero, 3 + beq $t1, $s0, .LPointwise.ProcessFilterCount3 + bltu $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 + ProcessPointwiseFilterCountN 4 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount3: + ProcessPointwiseFilterCountN 3 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCountLessThan3: + ori $s0, $zero, 2 + bltu $t1, $s0, .LPointwise.ProcessFilterCount1 + ProcessPointwiseFilterCountN 2 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount1: + ProcessPointwiseFilterCountN 1 + +// +// Restore non-volatile registers and return. +// + +.LPointwise.ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + +/*++ + +Macro Description: + + This macro generates code to clear the block accumulators. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + xr0-xr11 - Supplies the block accumulators. 
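+
+ Register layout (reviewer note): accumulator xr(f + 4*o) holds filter row
+ f of output column o, so xr0-xr3 cover output column 0, xr4-xr7 column 1,
+ and xr8-xr11 column 2.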
+ +--*/ + + .macro ClearBlock FilterCount, OutputCount + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvxor.v $xr4, $xr4, $xr4" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvxor.v $xr8, $xr8, $xr8" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvxor.v $xr1, $xr1, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvxor.v $xr5, $xr5, $xr5" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvxor.v $xr2, $xr2, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvxor.v $xr6, $xr6, $xr6" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvxor.v $xr10, $xr10, $xr10" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvxor.v $xr3, $xr3, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvxor.v $xr7, $xr7, $xr7" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvxor.v $xr11, $xr11, $xr11" + + .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S new file mode 100644 index 0000000000000..04b8dc14d067d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S @@ -0,0 +1,339 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLsx.S + +Abstract: + + This module implements the kernels for the single precision convolution + operation. + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "SconvKernelLsxCommon.h" + +/*++ + +Macro Description: + + This macro generates code to clear the block accumulators. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + vr0-vr7 - Supplies the block accumulators. + +--*/ + + .macro ClearBlock FilterCount, OutputCount + + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr0,$vr0,$vr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr1,$vr1,$vr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr2,$vr2,$vr2" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr3,$vr3,$vr3" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr4,$vr4,$vr4" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr5,$vr5,$vr5" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr6,$vr6,$vr6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr7,$vr7,$vr7" + + .endm + +/*++ + +Macro Description: + + This macro multiplies and accumulates for FilterCount by OutputCount block + of the output buffer. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + + VectorOffset - Supplies the byte offset from the filter buffer to fetch + elements. + + BroadcastOffset - Supplies the byte offset from the input buffer to fetch + elements. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a2 - Supplies the address of the filter buffer. + + a1 - Supplies the FilterStride parameter (see function description). 
+
+    t7 - Supplies the address of the filter buffer plus 2 * FilterStride.
+
+    a5 - Supplies the StrideWidth parameter (see function description).
+
+    vr0-vr7 - Supplies the block accumulators.
+
+--*/
+        .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset
+
+.ifeqs "\KernelType\()","Depthwise"
+        vld $vr8, $a2, 0
+        vld $vr9, $a2, 16
+        vld $vr10, $a3, 0
+        vld $vr11, $a3, 16
+        vfmadd.s $vr0, $vr8, $vr10, $vr0
+        vfmadd.s $vr1, $vr9, $vr11, $vr1
+.else
+        EmitIfCountGE \OutputCount\(), 1, "ld.w $s0, $a3, \BroadcastOffset\()"
+        EmitIfCountGE \OutputCount\(), 1, "vreplgr2vr.w $vr12, $s0"
+        EmitIfCountGE \FilterCount\(), 1, "vld $vr8, $a2, \VectorOffset\()"
+        EmitIfCountGE \FilterCount\(), 1, "vld $vr9, $a2, \VectorOffset\()+16"
+        EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr0, $vr8, $vr12, $vr0"
+        EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr1, $vr9, $vr12, $vr1"
+        EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, \VectorOffset\()"
+        EmitIfCountGE \FilterCount\(), 2, "vldx $vr8, $a2, $s0"
+        EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, \VectorOffset\()+16"
+        EmitIfCountGE \FilterCount\(), 2, "vldx $vr9, $a2, $s0"
+        EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr2, $vr8, $vr12, $vr2"
+        EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr3, $vr9, $vr12, $vr3"
+        EmitIfCountGE \FilterCount\(), 3, "vld $vr8, $t7, \VectorOffset\()"
+        EmitIfCountGE \FilterCount\(), 3, "vld $vr9, $t7, \VectorOffset\()+16"
+        EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr4, $vr8, $vr12, $vr4"
+        EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr5, $vr9, $vr12, $vr5"
+        EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()"
+        EmitIfCountGE \FilterCount\(), 4, "vldx $vr8, $t7, $s0"
+        EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()+16"
+        EmitIfCountGE \FilterCount\(), 4, "vldx $vr9, $t7, $s0"
+        EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr6, $vr8, $vr12, $vr6"
+        EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr7, $vr9, $vr12, $vr7"
+.endif
+        .endm
+/*++
+
+Macro Description:
+
+    This macro generates code to compute the convolution for a specified number
+    of filter rows.
+
+Arguments:
+
+    KernelFrame - Supplies the symbol name to access the convolution kernel
+        stack.
+
+    KernelType - Supplies the type of kernel to be generated.
+
+    FilterCount - Supplies the number of rows from the filter to process.
+
+Implicit Arguments:
+
+    a0 - Supplies the address of the input buffer.
+
+    a1 - Supplies the FilterStride parameter (see function description) when
+        KernelType!=Depthwise. Supplies the address of the filter buffer when
+        KernelType=Depthwise.
+
+    t8 - Supplies the DilationWidth parameter (see function description).
+
+    a4 - Supplies the address of the output buffer.
+
+    a5 - Supplies the StrideWidth parameter (see function description).
+
+    t5 - Supplies the InputStride parameter (see function description).
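+
+    A minimal C sketch of the loop this macro emits (editor's illustration;
+    the helper name and the function-pointer indirection are hypothetical):
+
+        typedef void (*OutputBlockFn)(const float* input);
+        static void ProcessFilterCountNRef(const float* input,
+                                           size_t strideWidthBytes,
+                                           size_t leftPad, size_t count,
+                                           size_t rightPad, OutputBlockFn fn) {
+            /* Padded and unpadded outputs share one loop because the
+               OutputCount==1 path re-checks padding per kernel tap. */
+            for (size_t i = 0; i < leftPad + count + rightPad; i++) {
+                fn(input);                  /* ProcessOutputCountN ... 1 */
+                input = (const float*)((const char*)input + strideWidthBytes);
+            }
+        }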
+
+--*/
+
+        .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount
+        ld.d $s0, $sp, OutputCountLeftPad_arg   // OutputCountLeftPad
+        ld.d $s1, $sp, OutputCount_arg          // OutputCount
+        add.d $s0, $s0, $s1
+        ld.d $s1, $sp, OutputCountRightPad_arg  // OutputCountRightPad
+        add.d $t0, $s0, $s1
+.L\KernelType\().\FilterCount\().ProcessNextOutputCount:
+        ProcessOutputCountN Sse, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 1
+        add.d $a0, $a0, $a5
+        addi.d $t0, $t0, -1
+        bnez $t0, .L\KernelType\().\FilterCount\().ProcessNextOutputCount
+        .endm
+
+/*++
+
+Macro Description:
+
+    This macro generates code to compute the convolution for a specified number
+    of filter rows for a pointwise convolution.
+
+Arguments:
+
+    FilterCount - Supplies the number of rows from the filter to process.
+
+Implicit Arguments:
+
+    a0 - Supplies the address of the input buffer.
+
+    a1 - Supplies the FilterStride parameter (see function description).
+
+    t8 - Supplies the InputStride parameter (see function description).
+
+    a4 - Supplies the address of the output buffer.
+
+    a5 - Supplies the StrideWidth parameter (see function description).
+
+    t0 - Supplies the OutputCount parameter (see function description).
+
+    t2 - Supplies the address of the filter buffer.
+
+--*/
+
+        .macro ProcessPointwiseFilterCountN FilterCount
+.LPointwise.\FilterCount\().ProcessNextOutputCount:
+        ProcessPointwiseOutputCountN Sse, 8, \FilterCount\(), 1
+        add.d $a0, $a0, $a5
+        addi.d $t0, $t0, -1
+        bnez $t0, .LPointwise.\FilterCount\().ProcessNextOutputCount
+        .endm
+
+//
+// Generate the convolution kernels.
+//
+
+        SconvKernelFunction Nchw, 8, LSX
+        SconvKernelFunction Nchwc, 8, LSX, BiasFilter
+        SconvKernelDepthwiseFunction 8, LSX
+        SconvKernelPointwiseFunction LSX, BiasFilter
+
+/*++
+
+Macro Description:
+
+    This macro generates code to process an output block after the inner
+    convolution kernel has executed and then stores the output block to the
+    output buffer.
+
+Arguments:
+
+    FilterCount - Supplies the number of rows from the filter to process.
+
+    OutputCount - Supplies the number of output blocks to produce.
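+
+    For one 8-wide output block, the post-processing below is equivalent to
+    this C sketch (editor's illustration; the flag values mirror the
+    MLAS_CONV_KERNEL_FLAG_* definitions in SconvKernelLsxCommon.h):
+
+        static void PostProcessRef(const float acc[8], float* out,
+                                   const float* bias, unsigned flags) {
+            for (int i = 0; i < 8; i++) {
+                float v = acc[i];
+                if (flags & 0x1) v += out[i];               /* accumulate */
+                if (flags & 0x2) v += bias[i];              /* bias addition */
+                if (flags & 0x4) v = v > 0.0f ? v : 0.0f;   /* fused ReLU */
+                out[i] = v;
+            }
+        }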
+--*/ + + .macro PostProcessBlock FilterCount, OutputCount + + .globl MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#if !defined(__APPLE__) + .hidden MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#endif +MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\(): + +.if \FilterCount\() > 2 + li.d $s0, 2 + mul.d $s0, $s0, $t6 + add.d $t7, $a4, $s0 +.endif + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a4, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr10, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr11, $a4, $s0" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $t7, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr14, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr15, $t7, $s0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: +// +// Test if the bias buffer should be accumulated with the output block. 
+// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a3, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a3, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr10, $a3, 32" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr11, $a3, 48" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $a3, 64" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $a3, 80" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr14, $a3, 96" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr15, $a3, 112" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: + +// +// Test for fused ReLU activation. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation + vxor.v $vr15,$vr15, $vr15 + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr0, $vr0, $vr15" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr1, $vr1, $vr15" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr2, $vr2, $vr15" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr3, $vr3, $vr15" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr4, $vr4, $vr15" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr5, $vr5, $vr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr6, $vr6, $vr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: + +// +// Store the output block in the output buffer. 
+//
+
+        EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr0, $a4, 0"
+        EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr1, $a4, 16"
+        EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr2, $a4, $t6"
+        EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16"
+        EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr3, $a4, $s0"
+        EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr4, $t7, 0"
+        EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr5, $t7, 16"
+        EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr6, $t7, $t6"
+        EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16"
+        EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr7, $t7, $s0"
+        add_immed $a4, \OutputCount\()*8*4      # advance output by N nchw8c blocks
+        jr $ra
+
+        .endm
+
+        .irp FilterCount, 1, 2, 3, 4
+        .irp OutputCount, 1
+        PostProcessBlock \FilterCount\(), \OutputCount\()
+        .endr
+        .endr
+
+        .end
diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h
new file mode 100644
index 0000000000000..d03714f654500
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h
@@ -0,0 +1,669 @@
+/*++
+
+Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    SconvKernelLsxCommon.h
+
+Abstract:
+
+    This module contains common kernel macros and structures for the single
+    precision convolution operation for the Lsx kernels.
+
+--*/
+
+#define SP_SIZE 32*8
+
+#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001
+#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002
+#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004
+#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008
+
+#define Filter_save_offset 18*8
+
+#define OutputStride_arg 6*8
+#define KernelHeight_arg 7*8
+#define KernelWidth_arg 8*8
+#define InputBase_arg 9*8
+#define InputWidth_arg 10*8
+#define DilatedInputWidth_arg 11*8
+#define OutputCountLeftPad_arg 12*8
+#define OutputCount_arg 13*8
+#define OutputCountRightPad_arg 14*8
+#define Bias_arg 15*8
+#define Flags_arg 16*8
+#define InputChannels_arg 17*8
+
+/*++
+
+Macro Description:
+
+    This macro generates code to compute the convolution for a vector of input
+    blocks and a vector of filter blocks to produce a matrix of output blocks.
+
+    OutputCount=1 generates special case code to handle padding blocks. All
+    other output counts assume no padding.
+
+Arguments:
+
+    Isa - Supplies the instruction set architecture string for function tags.
+
+    KernelFrame - Supplies the symbol name to access the convolution kernel
+        stack.
+
+    KernelType - Supplies the type of kernel to be generated.
+
+    BlockSize - Supplies the number of elements per block.
+
+    FilterCount - Supplies the number of rows from the filter to process.
+
+    OutputCount - Supplies the number of output blocks to produce.
+
+Implicit Arguments:
+
+    a0 - Supplies the address of the input buffer.
+
+    a1 - Supplies the FilterStride parameter (see function description) when
+        KernelType!=Depthwise. Supplies the address of the filter buffer when
+        KernelType=Depthwise.
+
+    t8 - Supplies the DilationWidth parameter (see function description).
+
+    a4 - Supplies the address of the output buffer.
+
+    a5 - Supplies the StrideWidth parameter (see function description).
+
+    t5 - Supplies the InputStride parameter (see function description).
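+
+    The row/column walk that this macro emits corresponds roughly to the
+    following C sketch (editor's illustration; the unsigned compare against
+    the blocked input width is how the OutputCount==1 path skips padding taps,
+    since a pointer left of InputBase wraps to a huge unsigned offset):
+
+        static void ConvWindowRef(const char* input, const char* inputBase,
+                                  size_t inputWidthBytes, size_t kernelHeight,
+                                  size_t kernelWidth, size_t dilationBytes,
+                                  size_t inputStrideBytes,
+                                  size_t dilatedInputWidthBytes, float acc[8]) {
+            (void)acc;
+            for (size_t kh = 0; kh < kernelHeight; kh++) {
+                for (size_t kw = 0; kw < kernelWidth; kw++) {
+                    if ((size_t)(input - inputBase) < inputWidthBytes) {
+                        /* ComputeBlock: fused multiply-accumulate into acc */
+                    }
+                    input += dilationBytes;
+                }
+                input += inputStrideBytes;              /* next input row */
+                inputBase += dilatedInputWidthBytes;    /* next valid row */
+            }
+        }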
+--*/ + + .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount + move $a3, $a0 +.ifeqs "\KernelType\()","Depthwise" + move $a2, $a1 +.else + ld.d $a2, $sp, Filter_save_offset +.endif + ld.d $t1, $sp, KernelHeight_arg //KernelHeight + ld.d $t2, $sp, KernelWidth_arg //KernelWidth +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg //InputBase + ld.d $t4, $sp, InputWidth_arg //InputWidth + sub.d $t3, $zero, $t3 # keep negative for lea usage below +.endif + ClearBlock \FilterCount\(), \OutputCount\() + beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: + move $t6, $t2 # reload kernel width remaining +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 + bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding +.endif +.if \OutputCount\() > 3 + li.d $s2, 2 + mul.d $s2, $a5, $s2 + add.d $t4, $a5, $s2 + + add.d $t4, $t4, $a3 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + li.d $s2, 2 + mul.d $s2, $s2, $a1 + add.d $t7, $a2, $s2 //t6 is rbx used by ComputeBlock +.endif +.ifeqs "\KernelType\()","Nchwc" +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif +.else + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 +.endif +.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $t8 # advance input by dilation width +.ifeqs "\KernelType\()","Nchwc" + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 + # advance filter by 8i8o/16i16o block +.else + addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block +.endif + addi.d $t6, $t6, -1 # decrement columns remaining + bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t5 +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg #DilatedInputWidth + sub.d $t3, $t3, $s0 + # advance input base to next row +.endif + addi.d $t1, $t1, -1 # decrement rows remaining + bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow + +// +// Handle post processing of the output block. +// +.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: + ld.w $a2, $sp, Flags_arg + +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() +.endm +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. + + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. 
+
+Arguments:
+
+    Input (a0) - Supplies the address of the input buffer.
+
+        The address is biased to include padding blocks for the left width
+        dimension. The address is not biased to include padding rows for the
+        left height dimension; these are accounted for in the outer kernel.
+
+    Filter (a1) - Supplies the address of the filter buffer.
+
+    Output (a2) - Supplies the address of the output buffer.
+
+    StrideWidth (a3) - Supplies the length in bytes of the blocked stride width.
+
+    DilationWidth (a4) - Supplies the length in bytes of the blocked dilation
+        width.
+
+    FilterCount (a5) - Supplies the number of filters to process in this
+        iteration.
+
+    InputStride (a6) - Supplies the length in bytes to advance the input buffer to
+        the next input row.
+
+    FilterStride (a7) - Supplies the length in bytes to advance the filter buffer
+        to the next set of filters.
+
+    OutputStride (sp, 8*0) - Supplies the length in bytes to advance the output buffer
+        to the next output address associated with the next set of filters.
+
+    KernelHeight (sp, 8*1) - Supplies the height of the kernel to apply. This height may
+        be less than the original kernel height after removing any padding
+        rows.
+
+    KernelWidth (sp, 8*2) - Supplies the width of the kernel to apply.
+
+    InputBase (sp, 8*3) - Supplies the address of the valid input buffer.
+
+        This parameter is similar to the Input parameter, but does not include
+        the padding blocks for the left width dimension. This parameter is used
+        with the following InputWidth parameter in order to validate that the
+        current input buffer address is in bounds and not in the left or right
+        width padding region.
+
+    InputWidth (sp, 8*4) - Supplies the length in bytes of the blocked input width.
+
+    DilatedInputWidth (sp, 8*5) - Supplies the length in bytes to advance the input base
+        buffer to the next input row including dilation.
+
+    OutputCountLeftPad (sp, 8*6) - Supplies the number of output elements that include
+        one or more padding elements from the left edge.
+
+    OutputCount (sp, 8*7) - Supplies the number of output elements that do not include
+        any padding elements.
+
+    OutputCountRightPad (sp, 8*8) - Supplies the number of output elements that include
+        one or more padding elements from the right edge.
+
+    Bias (sp, 8*9) - Supplies the address of the bias buffer.
+
+    Flags (sp, 8*10) - Supplies additional flags controlling the convolution operation,
+        especially post calculation options.
+
+Return Value:
+
+    None.
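+
+    As a hypothetical C view of the call (editor's sketch, not a declared MLAS
+    API), the stack-resident arguments that the prologue below spills into the
+    local frame can be pictured as:
+
+        struct SconvStackArgs {              /* read at sp + SP_SIZE + n*8 */
+            size_t OutputStride;             /* n = 0 */
+            size_t KernelHeight, KernelWidth;
+            const float* InputBase;
+            size_t InputWidth, DilatedInputWidth;
+            size_t OutputCountLeftPad, OutputCount, OutputCountRightPad;
+            const float* Bias;
+            size_t Flags;                    /* n = 10 */
+        };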
+
+--*/
+
+        FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\()
+        addi.d $sp, $sp, -SP_SIZE
+        st.d $s0, $sp, 0*8
+        st.d $s1, $sp, 1*8
+        st.d $s2, $sp, 2*8
+        st.d $s3, $sp, 3*8
+        st.d $s4, $sp, 4*8
+        st.d $ra, $sp, 5*8
+        ld.d $s0, $sp, SP_SIZE+0*8
+        ld.d $s1, $sp, SP_SIZE+1*8
+        ld.d $s2, $sp, SP_SIZE+2*8
+        ld.d $s3, $sp, SP_SIZE+3*8
+        st.d $s0, $sp, OutputStride_arg
+        st.d $s1, $sp, KernelHeight_arg
+        st.d $s2, $sp, KernelWidth_arg
+        st.d $s3, $sp, InputBase_arg
+        ld.d $s0, $sp, SP_SIZE+4*8
+        ld.d $s1, $sp, SP_SIZE+5*8
+        ld.d $s2, $sp, SP_SIZE+6*8
+        ld.d $s3, $sp, SP_SIZE+7*8
+        st.d $s0, $sp, InputWidth_arg
+        st.d $s1, $sp, DilatedInputWidth_arg
+        st.d $s2, $sp, OutputCountLeftPad_arg
+        st.d $s3, $sp, OutputCount_arg
+        ld.d $s0, $sp, SP_SIZE+8*8
+        ld.d $s1, $sp, SP_SIZE+9*8
+        ld.d $s2, $sp, SP_SIZE+10*8
+        st.d $s0, $sp, OutputCountRightPad_arg
+        st.d $s1, $sp, Bias_arg
+        st.d $s2, $sp, Flags_arg
+
+.ifeqs "\BiasFilter\()","BiasFilter"
+        addi.d $a1, $a1, 4*8*4
+.endif
+        st.d $a1, $sp, Filter_save_offset       // store Filter
+        move $a1, $a7
+        move $t5, $a6
+        move $t8, $a4                           # shuffle to Win64 register usage
+        move $t1, $a5
+        move $a4, $a2
+        move $a5, $a3
+
+        li.d $s0, 3
+        beq $t1, $s0, .L\KernelType\().ProcessFilterCount3
+        blt $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3
+        ProcessFilterCountN SconvKernelFrame, \KernelType\(), 4
+        b .L\KernelType\().ExitKernel
+
+.L\KernelType\().ProcessFilterCount3:
+        ProcessFilterCountN SconvKernelFrame, \KernelType\(), 3
+        b .L\KernelType\().ExitKernel
+
+.L\KernelType\().ProcessFilterCountLessThan3:
+        li.d $s0, 2
+        blt $t1, $s0, .L\KernelType\().ProcessFilterCount1
+        ProcessFilterCountN SconvKernelFrame, \KernelType\(), 2
+        b .L\KernelType\().ExitKernel
+
+.L\KernelType\().ProcessFilterCount1:
+        ProcessFilterCountN SconvKernelFrame, \KernelType\(), 1
+
+//
+// Restore non-volatile registers and return.
+//
+
+.L\KernelType\().ExitKernel:
+        ld.d $a1, $sp, Filter_save_offset       // restore Filter
+        ld.d $s0, $sp, 0*8
+        ld.d $s1, $sp, 1*8
+        ld.d $s2, $sp, 2*8
+        ld.d $s3, $sp, 3*8
+        ld.d $s4, $sp, 4*8
+        ld.d $ra, $sp, 5*8
+
+        addi.d $sp, $sp, SP_SIZE
+        jr $ra
+.endm
+
+/*++
+
+Macro Description:
+
+    This macro generates code for the inner convolution kernel for the special
+    case of a depthwise separable convolution.
+
+Arguments:
+
+    BlockSize - Supplies the number of elements per block.
+
+    Isa - Supplies the instruction set architecture string for function tags.
+
+--*/
+
+        .macro SconvKernelDepthwiseFunction BlockSize, Isa
+
+/*++
+
+Routine Description:
+
+    This routine is the inner kernel to compute a convolution for the elements
+    of an output row for a set of filter rows.
+
+    Depthwise separable convolutions are a form of grouped convolution where
+    the number of input and output channels per group are one.
+
+Arguments:
+
+    Input (a0) - Supplies the address of the input buffer.
+
+        The address is biased to include padding blocks for the left width
+        dimension. The address is not biased to include padding rows for the
+        left height dimension; these are accounted for in the outer kernel.
+
+    Filter (a1) - Supplies the address of the filter buffer.
+
+    Output (a2) - Supplies the address of the output buffer.
+
+    StrideWidth (a3) - Supplies the length in bytes of the blocked stride width.
+
+    DilationWidth (a4) - Supplies the length in bytes of the blocked dilation
+        width.
+
+    InputStride (a5) - Supplies the length in bytes to advance the input buffer
+        to the next input row.
+
+    KernelHeight (a6) - Supplies the height of the kernel to apply.
+        This height may be less than the original kernel height after removing
+        any padding rows.
+
+    KernelWidth (a7) - Supplies the width of the kernel to apply.
+
+    InputBase (sp, 0*8) - Supplies the address of the valid input buffer.
+
+        This parameter is similar to the Input parameter, but does not include
+        the padding blocks for the left width dimension. This parameter is used
+        with the following InputWidth parameter in order to validate that the
+        current input buffer address is in bounds and not in the left or right
+        width padding region.
+
+    InputWidth (sp, 1*8) - Supplies the length in bytes of the blocked input width.
+
+    DilatedInputWidth (sp, 2*8) - Supplies the length in bytes to advance the input base
+        buffer to the next input row including dilation.
+
+    OutputCountLeftPad (sp, 3*8) - Supplies the number of output elements that include
+        one or more padding elements from the left edge.
+
+    OutputCount (sp, 4*8) - Supplies the number of output elements that do not include
+        any padding elements.
+
+    OutputCountRightPad (sp, 5*8) - Supplies the number of output elements that include
+        one or more padding elements from the right edge.
+
+    Bias (sp, 6*8) - Supplies the address of the bias buffer.
+
+    Flags (sp, 7*8) - Supplies additional flags controlling the convolution operation,
+        especially post calculation options.
+
+Return Value:
+
+    None.
+
+--*/
+
+        FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\()
+        addi.d $sp, $sp, -SP_SIZE
+        st.d $s0, $sp, 0*8
+        st.d $s1, $sp, 1*8
+        st.d $s2, $sp, 2*8
+        st.d $s3, $sp, 3*8
+        st.d $s4, $sp, 4*8
+        st.d $ra, $sp, 5*8
+
+        st.d $a6, $sp, KernelHeight_arg
+        st.d $a7, $sp, KernelWidth_arg
+
+        ld.d $s0, $sp, SP_SIZE+0*8
+        ld.d $s1, $sp, SP_SIZE+1*8
+        ld.d $s2, $sp, SP_SIZE+2*8
+        ld.d $s3, $sp, SP_SIZE+3*8
+        st.d $s0, $sp, InputBase_arg
+        st.d $s1, $sp, InputWidth_arg
+        st.d $s2, $sp, DilatedInputWidth_arg
+        st.d $s3, $sp, OutputCountLeftPad_arg
+        ld.d $s0, $sp, SP_SIZE+4*8
+        ld.d $s1, $sp, SP_SIZE+5*8
+        ld.d $s2, $sp, SP_SIZE+6*8
+        ld.d $s3, $sp, SP_SIZE+7*8
+        st.d $s0, $sp, OutputCount_arg
+        st.d $s1, $sp, OutputCountRightPad_arg
+        st.d $s2, $sp, Bias_arg
+        st.d $s3, $sp, Flags_arg
+
+//
+// Process the specified number of filter rows.
+//
+
+        move $t8, $a4                           // shuffle to Win64 register usage
+        move $t5, $a5
+        move $a4, $a2
+        move $a5, $a3
+        ProcessFilterCountN SconvKernelDepthwiseFrame, Depthwise, 1
+
+//
+// Restore non-volatile registers and return.
+//
+
+        ld.d $s0, $sp, 0*8
+        ld.d $s1, $sp, 1*8
+        ld.d $s2, $sp, 2*8
+        ld.d $s3, $sp, 3*8
+        ld.d $s4, $sp, 4*8
+        ld.d $ra, $sp, 5*8
+        addi.d $sp, $sp, SP_SIZE
+        jr $ra
+.endm
+
+/*++
+
+Macro Description:
+
+    This macro generates code to compute the convolution for a vector of input
+    blocks and a vector of filter blocks to produce a matrix of output blocks
+    for a pointwise convolution.
+
+Arguments:
+
+    Isa - Supplies the instruction set architecture string for function tags.
+
+    BlockSize - Supplies the number of elements per block.
+
+    FilterCount - Supplies the number of rows from the filter to process.
+
+    OutputCount - Supplies the number of output blocks to produce.
+
+Implicit Arguments:
+
+    (a0) - Supplies the address of the input buffer.
+
+    (a1) - Supplies the FilterStride parameter (see function description).
+
+    (t8) - Supplies the InputStride parameter (see function description).
+
+    (a4) - Supplies the address of the output buffer.
+
+    (a5) - Supplies the StrideWidth parameter (see function description).
+
+    (t2) - Supplies the address of the filter buffer.
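+
+    For BlockSize==8 and a single output block, the channel-block loop below
+    behaves like this C sketch (editor's illustration; the layout comments
+    follow the 8i8o filter blocking named in the code):
+
+        static void PointwiseRef(const float* input, const float* filter,
+                                 size_t channelBlocks, size_t inputStrideBytes,
+                                 float acc[8]) {
+            for (size_t cb = 0; cb < channelBlocks; cb++) {
+                for (int i = 0; i < 8; i++)         /* input channels */
+                    for (int o = 0; o < 8; o++)     /* output channels */
+                        acc[o] += input[i] * filter[i * 8 + o];
+                input = (const float*)((const char*)input + inputStrideBytes);
+                filter += 8 * 8;                    /* next 8i8o block */
+            }
+        }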
+
+--*/
+
+        .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount
+
+        move $a3, $a0
+        move $a2, $t2
+        ld.d $t1, $sp, InputChannels_arg
+        ClearBlock \FilterCount\(), \OutputCount\()
+
+.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock:
+.if \OutputCount\() > 3
+        li.d $s0, 2
+        mul.d $s0, $s0, $a5
+        add.d $t4, $a5, $s0
+        add.d $t4, $t4, $a3                     # compute input plus 3 blocks
+.endif
+.if \FilterCount\() > 2
+        li.d $s0, 2                             # compute filter plus 2 rows
+        mul.d $s0, $s0, $a1
+        add.d $t7, $a2, $s0
+.endif
+
+.if \BlockSize\() == 16
+        .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+        ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4
+        .endr
+.else
+        .irp Index, 0, 1, 2, 3, 4, 5, 6, 7
+        ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4
+        .endr
+.endif
+        add.d $a3, $a3, $t8                     # advance input to next channel block
+        addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4
+                                                # advance filter by 8i8o/16i16o block
+        addi.d $t1, $t1, -1                     # decrement input channel blocks remaining
+        bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock
+
+//
+// Handle post processing of the output block.
+//
+
+        ld.w $a2, $sp, Flags_arg                # load Flags
+.if \FilterCount\() > 1
+        ld.d $t6, $sp, OutputStride_arg         # load OutputStride
+.endif
+        ld.d $a3, $sp, Bias_arg                 # load Bias
+        bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\()
+.endm
+
+        .macro SconvKernelPointwiseFunction Isa, BiasFilter
+
+/*++
+
+Routine Description:
+
+    This routine is the inner kernel to compute a convolution for the elements
+    of an output row for a set of filter rows.
+
+    Pointwise convolutions have a kernel size of one. To simplify this
+    implementation, no input padding is allowed, which matches typical usage in
+    models.
+
+Arguments:
+
+    Input (a0) - Supplies the address of the input buffer.
+
+    Filter (a1) - Supplies the address of the filter buffer.
+
+    Output (a2) - Supplies the address of the output buffer.
+
+    StrideWidth (a3) - Supplies the length in bytes of the blocked stride width.
+
+    InputChannels (a4) - Supplies the number of input channels to process.
+
+    FilterCount (a5) - Supplies the number of rows from the filter to process.
+
+    InputStride (a6) - Supplies the length in bytes to advance the input buffer to
+        the next input channel of the same input row.
+
+    FilterStride (a7) - Supplies the length in bytes to advance the filter buffer
+        to the next set of filters.
+
+    OutputStride (sp+0) - Supplies the length in bytes to advance the output buffer
+        to the next output address associated with the next set of filters.
+
+    OutputCount (sp+8) - Supplies the number of output elements.
+
+    Bias (sp+16) - Supplies the address of the bias buffer.
+
+    Flags (sp+24) - Supplies additional flags controlling the convolution operation,
+        especially post calculation options.
+
+Return Value:
+
+    None.
+ +--*/ + + FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + + ld.d $s0, $sp, SP_SIZE+0*8 + ld.d $s1, $sp, SP_SIZE+1*8 + ld.d $s2, $sp, SP_SIZE+2*8 + ld.d $s3, $sp, SP_SIZE+3*8 + st.d $s0, $sp, OutputStride_arg + st.d $s1, $sp, OutputCount_arg + st.d $s2, $sp, Bias_arg + st.d $s3, $sp, Flags_arg + st.d $a4, $sp, InputChannels_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $t2, $a1, 4*8*4 +.else + move $t2, $a1 +.endif + + ld.d $t0, $sp, OutputCount_arg //OutputCount + move $a1, $a7 // FilterStride + move $t8, $a6 // InputStride + move $t1, $a5 // shuffle to Win64 register usage + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + li.d $s0, 3 + beq $t1, $s0, .LPointwise.ProcessFilterCount3 + blt $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 + ProcessPointwiseFilterCountN 4 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount3: + ProcessPointwiseFilterCountN 3 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCountLessThan3: + li.d $s0, 2 + blt $t1, $s0, .LPointwise.ProcessFilterCount1 + ProcessPointwiseFilterCountN 2 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount1: + ProcessPointwiseFilterCountN 1 + +// +// Restore non-volatile registers and return. +// +.LPointwise.ExitKernel: + + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra +.endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h new file mode 100644 index 0000000000000..93b109c90ae4f --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h @@ -0,0 +1,35 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision matrix/matrix multiply operation (SGEMM). + +--*/ + +// +// Define the single precision parameters. +// + +#define LFgemmElementShift 2 +#define LFgemmElementSize (1 << LFgemmElementShift) +#define LFgemmYmmElementCount (32/LFgemmElementSize) + +#include "FgemmKernelCommon.h" + +// +// Define the typed instructions for single precision. +// + +FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.s) +FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.s) +FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.w) +FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.s) diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S new file mode 100644 index 0000000000000..d537742016d01 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S @@ -0,0 +1,33 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + + This implementation uses LASX instructions. + +--*/ + +#include "asmmacro.h" +#include "SgemmKernelCommon.h" +#include "FgemmKernelLasxCommon.h" + + + .text + +// +// Generate the GEMM kernel. 
+// + +FgemmKernelLasxFunction MlasGemmFloatKernelLasx + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S new file mode 100644 index 0000000000000..86b5ef8b51b00 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S @@ -0,0 +1,267 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelLsx.s + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "FgemmKernelLsxCommon.h" + +FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.s) + +/*++ + +Macro Description: + + This macro multiplies and accumulates for a 16xN block of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + Shuffle - Supplies the shuffle mask to extract the element from matrix A. + +Implicit Arguments: + + a1 - Supplies the address into the matrix B data. + + vr0-vr1 - Supplies up to four elements loaded from matrix A and matrix A + plus one row. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockSseBy16 RowCount, VectorOffset, Shuffle + vld $vr4, $a1, \VectorOffset + vld $vr5, $a1, \VectorOffset + 16 + vreplvei.w $vr2, $vr0, \Shuffle +.if \RowCount\() == 2 + vreplvei.w $vr3, $vr1, \Shuffle + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.s $vr8, $vr4, $vr2, $vr8 + vfmadd.s $vr9, $vr5, $vr2, $vr9 +.if \RowCount\() == 2 + vfmadd.s $vr12, $vr6, $vr3, $vr12 + vfmadd.s $vr13, $vr7, $vr3, $vr13 +.endif + vld $vr4, $a1, \VectorOffset + 32 + vld $vr5, $a1, \VectorOffset + 48 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.s $vr10, $vr4, $vr2, $vr10 + vfmadd.s $vr11, $vr5, $vr2, $vr11 +.if \RowCount\() == 2 + vfmadd.s $vr14, $vr6, $vr3, $vr14 + vfmadd.s $vr15, $vr7, $vr3, $vr15 +.endif + .endm + + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t8 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t7 - Supplies the length in bytes of a row from matrix A. + + t5 - Supplies the length in bytes of a row from matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. 
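+
+    The 16-wide micro-kernel that follows computes, per iteration over K, the
+    equivalent of this C sketch (editor's illustration; B is assumed packed
+    16 floats per K step, as produced by the pack routines later in this diff):
+
+        static void Sgemm16x2Ref(const float* A, size_t ldaBytes,
+                                 const float* B, float* C, size_t ldcBytes,
+                                 size_t k, float alpha, int zeroMode) {
+            float acc[2][16] = {{0.0f}};
+            const float* a1 = (const float*)((const char*)A + ldaBytes);
+            for (size_t p = 0; p < k; p++)
+                for (int c = 0; c < 16; c++) {
+                    acc[0][c] += A[p] * B[p * 16 + c];
+                    acc[1][c] += a1[p] * B[p * 16 + c];
+                }
+            float* c1 = (float*)((char*)C + ldcBytes);
+            for (int c = 0; c < 16; c++) {
+                C[c]  = alpha * acc[0][c] + (zeroMode ? 0.0f : C[c]);
+                c1[c] = alpha * acc[1][c] + (zeroMode ? 0.0f : c1[c]);
+            }
+        }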
+ +--*/ + + .macro ProcessCountM RowCount, Fallthrough +.LProcessNextColumnLoop16xN\@: + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8, $vr8,$vr8" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9, $vr9,$vr9" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10, $vr10,$vr10" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11, $vr11,$vr11" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12, $vr12,$vr12" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13, $vr13,$vr13" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14, $vr14,$vr14" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15, $vr15,$vr15" + move $t8, $a3 + li.d $s0, 4 + blt $t8, $s0, .LProcessRemaining16xNBlocks\@ +.LCompute16xNBlockBy4Loop\@: + EmitIfCountGE \RowCount\(), 1, "vld $vr0, $a0, 0" + EmitIfCountGE \RowCount\(), 2, "vldx $vr1, $a0, $t0" #second line of A + ComputeBlockSseBy16 2, 0, 0x0 + ComputeBlockSseBy16 2, 16*4, 0x1 + addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns + ComputeBlockSseBy16 2, 0, 0x2 + ComputeBlockSseBy16 2, 16*4, 0x3 + addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns + addi.d $a0, $a0, 4*4 # advance matrix A by 4 columns + addi.d $t8, $t8, -4 + li.d $s0, 4 #check matrix A remaining less than 4 + bge $t8, $s0, .LCompute16xNBlockBy4Loop\@ + +.LProcessRemaining16xNBlocks\@: + beqz $t8, .LOutput16xNBlock\@ + +.LCompute16xNBlockBy1Loop\@: + EmitIfCountGE \RowCount\(), 1, "ld.w $s0, $a0, 0" + EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.w $vr0, $s0, 0" + EmitIfCountGE \RowCount\(), 2, "ldx.w $s0,$a0, $t0" + EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.w $vr1,$s0, 0" + ComputeBlockSseBy16 2, 0, 0x00 + addi.d $a1, $a1, 16*4 #advance matrix B by 16 columns + addi.d $a0, $a0, 1*4 #advance matrix A by 1 column + addi.d $t8, $t8, -1 + bnez $t8, .LCompute16xNBlockBy1Loop\@ + +.LOutput16xNBlock\@: + movfr2gr.s $s0, $f24 + vreplgr2vr.w $vr2, $s0 + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr8,$vr8,$vr2" + # multiply by alpha + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr9,$vr9,$vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr10,$vr10,$vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr11,$vr11,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr12,$vr12,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr13,$vr13,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr14,$vr14,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr15,$vr15,$vr2" + li.d $s0, 16 + blt $a5, $s0, .LOutputPartial16xNBlock\@ + sub.d $a5, $a5, $s0 + AccumulateAndStoreBlock \RowCount\(), 4 + addi.d $a2, $a2, 16*4 # advance matrix C by 16 columns + move $a0, $t1 # reload matrix A + bnez $a5, .LProcessNextColumnLoop16xN\@ + b .LExitKernel + +// +// Output a partial 16xN block to the matrix. 
+// + +.LOutputPartial16xNBlock\@: + li.d $s0, 4 + blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ + li.d $s0, 8 + blt $a5, $s0, .LOutputPartialLessThan8xNBlock\@ + li.d $s0, 12 + blt $a5, $s0, .LOutputPartialLessThan12xNBlock\@ + AccumulateAndStoreBlock \RowCount\(), 3 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr11" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr15" + addi.d $a2, $a2,12*4 # advance matrix C by 12 columns + b .LOutputPartialLessThan4xNBlock\@ + +.LOutputPartialLessThan12xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 2 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr10" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr14" + addi.d $a2, $a2,8*4 # advance matrix C by 8 columns + b .LOutputPartialLessThan4xNBlock\@ + +.LOutputPartialLessThan8xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 1 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr9" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr13" + addi.d $a2, $a2, 4*4 # advance matrix C by 4 columns + +.LOutputPartialLessThan4xNBlock\@: + andi $s0, $a5, 2 + beqz $s0, .LOutputPartial1xNBlock\@ + and $s0, $t5, $t5 # ZeroMode? + bnez $s0, .LSkipAccumulateOutput2xN\@ + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr0, $vr0, $vr0" + EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.d $vr0, $s0, 0" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr1, $vr1, $vr1" + EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.d $vr1, $s0, 0" + EmitIfCountGE \RowCount\(), 1, "vfadd.s $vr8, $vr8, $vr0" + EmitIfCountGE \RowCount\(), 2, "vfadd.s $vr12, $vr12, $vr1" + +.LSkipAccumulateOutput2xN\@: + EmitIfCountGE \RowCount\(), 1, "vstelm.d $vr8, $a2, 0, 0" + EmitIfCountGE \RowCount\(), 2, "vpickve2gr.d $s0, $vr12, 0" + EmitIfCountGE \RowCount\(), 2, "stx.d $s0, $a2, $t6" + andi $s0, $a5, 1 + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vpermi.w $vr8, $vr8, 0xee" + # shift third element down + EmitIfCountGE \RowCount\(), 2, "vpermi.w $vr12, $vr12, 0xee" + addi.d $a2, $a2, 2*4 # advance matrix C by 2 columns + +.LOutputPartial1xNBlock\@: + and $s0, $t5, $t5 # ZeroMode? + bnez $s0, .LSkipAccumulateOutput1xN\@ + + EmitIfCountGE \RowCount\(), 1, "fld.s $f16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "fadd.s $f8, $f16, $f8" + EmitIfCountGE \RowCount\(), 2, "fldx.s $f17, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "fadd.s $f12, $f12, $f17" + +.LSkipAccumulateOutput1xN\@: + EmitIfCountGE \RowCount\(), 1, "fst.s $f8, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "fstx.s $f12, $a2, $t6" +.ifb \Fallthrough\() + b .LExitKernel +.endif + .endm + +// +// Generate the GEMM kernel. +// + +FgemmKernelLsxFunction MlasGemmFloatKernelLSX + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S new file mode 100644 index 0000000000000..cd1747745d2a4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S @@ -0,0 +1,89 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + SgemmTransposePackB16x4LSX.s + +Abstract: + + This module implements routines for packing buffers for the single precision + matrix/matrix multiply operation (SGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Routine Description: + + This routine transposes elements from the source matrix to the destination + packed buffer. + + 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 + rows in the destination packed buffer. + +Arguments: + + D (a0) - Supplies the address of the destination packed buffer. + + B (a1) - Supplies the address of the source matrix. + + ldb (a2) - Supplies the number of elements per row of the source matrix. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasSgemmTransposePackB16x4LSX + addi.d $sp, $sp, -64 + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + slli.d $a2, $a2, 2 # convert ldb to bytes + ori $a3, $zero, 4 # transpose four 4x4 blocks + vxor.v $vr7, $vr7, $vr7 +.LTransposeBlockLoop: + slli.d $s0, $a2, 1 + add.d $s1, $a1, $s0 + vld $vr0, $a1, 0 + vldx $vr1, $a1, $a2 + vld $vr2, $s1, 0 + vldx $vr3, $s1, $a2 + + vor.v $vr4, $vr0, $vr7 + vilvl.w $vr4, $vr1, $vr4 + vilvh.w $vr0, $vr1, $vr0 + vor.v $vr5, $vr2, $vr7 + vilvl.w $vr5, $vr3, $vr5 + vilvh.w $vr2, $vr3, $vr2 + vor.v $vr1, $vr4, $vr7 + vilvl.d $vr1, $vr5, $vr1 + vilvh.d $vr4, $vr5, $vr4 + vor.v $vr3, $vr0, $vr7 + vilvl.d $vr3, $vr2, $vr3 + vilvh.d $vr0, $vr2, $vr0 + vst $vr1, $a0, 0 + vst $vr4, $a0, 0x40 + vst $vr3, $a0, 0x80 + vst $vr0, $a0, 0xc0 + addi.d $a0, $a0, 0x10 + slli.d $s0, $a2, 1 + add.d $a1, $s0, $s1 + addi.d $a3, $a3, -1 + bnez $a3, .LTransposeBlockLoop + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + addi.d $sp, $sp, 64 + jr $ra + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S new file mode 100644 index 0000000000000..e617419989c4d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S @@ -0,0 +1,126 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmTransposePackB16x4Lasx.s + +Abstract: + + This module implements routines for packing buffers for the single precision + matrix/matrix multiply operation (SGEMM). + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Macro Description: + + 4 columns of 8 rows from the source matrix are transposed to 8 columns of 4 + rows in the destination packed buffer. + +Arguments: + + StoreOffset - Supplies the relative byte offset into the destination packed + buffer. + +Implicit Arguments: + + a0 - Supplies the address of the destination packed buffer. + + a1 - Supplies the address of the source matrix. + + a2 - Supplies the number of elements per row of the source matrix. + +--*/ + + .macro TransposePackB8x4BlockLasx StoreOffset + +// +// Load 4 columns from 8 rows of the source matrix into the lower and upper +// halves of 4 XR registers. 
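+//
+// Editor's note: as a scalar reference (hypothetical C), the 8x4 block handled
+// by this macro is transposed into the 16-float-wide packed rows like so:
+//
+//      for (int r = 0; r < 8; r++)
+//          for (int c = 0; c < 4; c++)
+//              D[c * 16 + r] = B[r * ldb + c];
+//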
+// + + add.d $t0, $a2, $a2 + add.d $t6, $a1, $t0 + vld $vr0, $a1, 0 + vldx $vr1, $a1, $a2 + add.d $t0, $a2, $a2 + add.d $a1, $t6, $t0 + vld $vr2, $t6, 0 + vldx $vr3, $t6, $a2 + add.d $t0, $a2, $a2 + add.d $t6, $a1, $t0 + + vld $vr4, $a1, 0 + xvpermi.q $xr0, $xr4, 0x2 + vldx $vr5, $a1, $a2 + xvpermi.q $xr1, $xr5, 0x2 + vld $vr4, $t6, 0 + xvpermi.q $xr2, $xr4, 0x2 + vldx $vr5, $t6, $a2 + xvpermi.q $xr3, $xr5, 0x2 + +// +// Transpose the lower and upper halves of the 4 XR registers as two 4x4 +// matrices and store the output to the destination packed buffer. +// + + xvilvl.w $xr4, $xr1, $xr0 + xvilvh.w $xr5, $xr1, $xr0 + xvilvl.w $xr0, $xr3, $xr2 + xvilvh.w $xr1, $xr3, $xr2 + xvilvl.d $xr2, $xr0, $xr4 + xvilvh.d $xr3, $xr0, $xr4 + xvst $xr2, $a0, \StoreOffset\() + xvst $xr3, $a0, 0x40+\StoreOffset\() + xvilvl.d $xr0, $xr1, $xr5 + xvilvh.d $xr4, $xr1, $xr5 + xvst $xr0, $a0, 0x80+\StoreOffset\() + xvst $xr4, $a0, 0xc0+\StoreOffset\() + + .endm + +/*++ + +Routine Description: + + This routine transposes elements from the source matrix to the destination + packed buffer. + + 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 + rows in the destination packed buffer. + +Arguments: + + D (a0) - Supplies the address of the destination packed buffer. + + B (a1) - Supplies the address of the source matrix. + + ldb (a2) - Supplies the number of elements per row of the source matrix. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasSgemmTransposePackB16x4Lasx + + slli.d $a2, $a2, 2 # convert ldb to bytes + TransposePackB8x4BlockLasx 0*4 + add.d $t0, $a2, $a2 + add.d $a1, $t0, $t6 + TransposePackB8x4BlockLasx 8*4 + jr $ra + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S new file mode 100644 index 0000000000000..aaaa3cbf9138d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S @@ -0,0 +1,357 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SoftmaxKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision softmax + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to find the maximum value of + the supplied buffer. + +Arguments: + + Input (a0) - Supplies the input buffer. + + N (a1) - Supplies the number of elements to process. + +Return Value: + + Returns the maximum value of the supplied buffer. 
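+
+    A scalar model of this kernel (editor's sketch; MlasMinimumF32Value holds
+    the smallest finite float, matching the 0xFF7FFFFF bit pattern used
+    elsewhere in this diff):
+
+        #include <float.h>
+        static float ReduceMaximumRef(const float* input, size_t n) {
+            float m = -FLT_MAX;             /* MlasMinimumF32Value */
+            for (size_t i = 0; i < n; i++)
+                if (input[i] > m) m = input[i];
+            return m;
+        }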
+ +--*/ + + FUNCTION_ENTRY MlasReduceMaximumF32KernelLasx + addi.d $sp, $sp, -32 + + la.global $t0, MlasMinimumF32Value + ld.w $t0, $t0, 0 + xvreplgr2vr.w $xr0, $t0 + beqz $a1, .LReduceMaximum.ExitKernel + ori $t0, $zero, 8 + bltu $a1, $t0, .LReduceMaximum.ProcessRemainingCountBy1 + ori $t1, $zero, 32 + bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy8 + xvreplgr2vr.w $xr16, $zero + xvor.v $xr1, $xr0, $xr16 + xvor.v $xr2, $xr0, $xr16 + xvor.v $xr3, $xr0, $xr16 + +.LReduceMaximum.ProcessRemainingCountBy32: + xvld $xr16, $a0, 0 + xvfmax.s $xr0, $xr0, $xr16 + xvld $xr16, $a0, 8*4 + xvfmax.s $xr1, $xr1, $xr16 + addi.d $a1, $a1, -0x20 + xvld $xr16, $a0, 16*4 + xvfmax.s $xr2, $xr2, $xr16 + xvld $xr16, $a0, 24*4 + xvfmax.s $xr3, $xr3, $xr16 + addi.d $a0, $a0, 32*4 # advance input by 32 elements + ori $t1, $zero, 32 + bgeu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy32 + xvfmax.s $xr0, $xr0, $xr1 + xvfmax.s $xr2, $xr2, $xr3 + xvfmax.s $xr0, $xr0, $xr2 + +.LReduceMaximum.ProcessRemainingCountBy8: + ori $t1, $zero, 8 + bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountLessThan8 + xvld $xr16, $a0, 0 + xvfmax.s $xr0, $xr0, $xr16 + addi.d $a1, $a1, -8 + addi.d $a0, $a0, 8*4 + b .LReduceMaximum.ProcessRemainingCountBy8 + +.LReduceMaximum.ProcessRemainingCountLessThan8: + xvst $xr0, $sp, 0 + vld $vr1, $sp, 0x10 + vld $vr0, $sp, 0 + vfmax.s $vr0, $vr0, $vr1 + vshuf4i.w $vr1, $vr0, 0xee + vfmax.s $vr0, $vr0, $vr1 + vshuf4i.w $vr1, $vr0, 0x55 + vfmax.s $vr0, $vr0, $vr1 + beqz $a1, .LReduceMaximum.ExitKernel + +.LReduceMaximum.ProcessRemainingCountBy1: + vld $vr16, $a0, 0 + vfmax.s $vr0, $vr0, $vr16 + addi.d $a0, $a0, 4 # advance input by 1 element + addi.d $a1, $a1, -1 + bnez $a1, .LReduceMaximum.ProcessRemainingCountBy1 + +.LReduceMaximum.ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + addi.d $sp, $sp, 32 + jr $ra + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to produce the final output for + the softmax operation. + +Arguments: + + Output (a0) - Supplies the output buffer. + + N (a1) - Supplies the number of elements to process. + + Parameters (a2) - Supplies an array containing the scale value. + +Return Value: + + None. 
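+
+    Functionally this kernel is the final scale of softmax (editor's C sketch;
+    Parameters[0] is assumed to hold 1 / sum(exp(x - max)) as computed by the
+    caller):
+
+        static void SoftmaxOutputRef(float* output, size_t n,
+                                     const float* parameters) {
+            const float scale = parameters[0];
+            for (size_t i = 0; i < n; i++)
+                output[i] *= scale;
+        }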
+
+--*/
+
+        FUNCTION_ENTRY MlasComputeSoftmaxOutputF32KernelLasx
+
+        ld.w $t0, $a2, 0
+        xvreplgr2vr.w $xr4, $t0
+        ori $t1, $zero, 0x20
+        bltu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy8
+
+.LComputeSoftmaxOutput.ProcessRemainingCountBy32:
+        xvld $xr16, $a0, 0
+        xvfmul.s $xr0, $xr4, $xr16
+        xvld $xr16, $a0, 8*4
+        xvfmul.s $xr1, $xr4, $xr16
+        addi.d $a1, $a1, -0x20
+        xvld $xr16, $a0, 16*4
+        xvfmul.s $xr2, $xr4, $xr16
+        xvld $xr16, $a0, 24*4
+        xvfmul.s $xr3, $xr4, $xr16
+        xvst $xr0, $a0, 0
+        xvst $xr1, $a0, 8*4
+        xvst $xr2, $a0, 16*4
+        xvst $xr3, $a0, 24*4
+        addi.d $a0, $a0, 0x80                   # advance output by 32 elements
+        bgeu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy32
+
+.LComputeSoftmaxOutput.ProcessRemainingCountBy8:
+        ori $t2, $zero, 8
+        bltu $a1, $t2, .LComputeSoftmaxOutput.ProcessRemainingCountLessThan8
+        xvld $xr16, $a0, 0
+        xvfmul.s $xr0, $xr4, $xr16
+        addi.d $a1, $a1, -8
+        xvst $xr0, $a0, 0
+        addi.d $a0, $a0, 8*4                    # advance output by 8 elements
+        b .LComputeSoftmaxOutput.ProcessRemainingCountBy8
+
+.LComputeSoftmaxOutput.ProcessRemainingCountLessThan8:
+        beqz $a1, .LComputeSoftmaxOutput.ExitKernel
+
+.LComputeSoftmaxOutput.ProcessRemainingCountBy1:
+        fld.s $f16, $a0, 0
+        fmul.s $f0, $f4, $f16
+        fst.s $f0, $a0, 0
+        addi.d $a0, $a0, 4                      # advance output by 1 element
+        addi.d $a1, $a1, -1
+        bnez $a1, .LComputeSoftmaxOutput.ProcessRemainingCountBy1
+
+.LComputeSoftmaxOutput.ExitKernel:
+        xvinsgr2vr.d $xr0, $zero, 2
+        xvinsgr2vr.d $xr0, $zero, 3
+        xvinsgr2vr.d $xr1, $zero, 2
+        xvinsgr2vr.d $xr1, $zero, 3
+        xvinsgr2vr.d $xr2, $zero, 2
+        xvinsgr2vr.d $xr2, $zero, 3
+        xvinsgr2vr.d $xr3, $zero, 2
+        xvinsgr2vr.d $xr3, $zero, 3
+        xvinsgr2vr.d $xr4, $zero, 2
+        xvinsgr2vr.d $xr4, $zero, 3
+        xvinsgr2vr.d $xr5, $zero, 2
+        xvinsgr2vr.d $xr5, $zero, 3
+        xvinsgr2vr.d $xr6, $zero, 2
+        xvinsgr2vr.d $xr6, $zero, 3
+        xvinsgr2vr.d $xr7, $zero, 2
+        xvinsgr2vr.d $xr7, $zero, 3
+        xvinsgr2vr.d $xr8, $zero, 2
+        xvinsgr2vr.d $xr8, $zero, 3
+        xvinsgr2vr.d $xr9, $zero, 2
+        xvinsgr2vr.d $xr9, $zero, 3
+        xvinsgr2vr.d $xr10, $zero, 2
+        xvinsgr2vr.d $xr10, $zero, 3
+        xvinsgr2vr.d $xr11, $zero, 2
+        xvinsgr2vr.d $xr11, $zero, 3
+        xvinsgr2vr.d $xr12, $zero, 2
+        xvinsgr2vr.d $xr12, $zero, 3
+        xvinsgr2vr.d $xr13, $zero, 2
+        xvinsgr2vr.d $xr13, $zero, 3
+        xvinsgr2vr.d $xr14, $zero, 2
+        xvinsgr2vr.d $xr14, $zero, 3
+        xvinsgr2vr.d $xr15, $zero, 2
+        xvinsgr2vr.d $xr15, $zero, 3
+        jr $ra
+
+/*++
+
+Routine Description:
+
+    This routine implements a vectorized kernel to produce the final output for
+    the log softmax operation.
+
+Arguments:
+
+    Input (a0) - Supplies the input buffer.
+
+    Output (a1) - Supplies the output buffer.
+
+    N (a2) - Supplies the number of elements to process.
+
+    Parameters (a3) - Supplies an array containing the negative maximum and
+        logarithm values.
+
+Return Value:
+
+    None.
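+
+    The kernel below evaluates log-softmax in two dependent steps, which the
+    inline comments call out for numeric stability. In C terms (editor's
+    sketch):
+
+        static void LogSoftmaxOutputRef(const float* input, float* output,
+                                        size_t n, const float* parameters) {
+            const float negativeMaximum = parameters[0];
+            const float logarithm = parameters[1];  /* log(sum(exp(x - max))) */
+            for (size_t i = 0; i < n; i++)
+                output[i] = (input[i] + negativeMaximum) - logarithm;
+        }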
+
+--*/
+
+        FUNCTION_ENTRY MlasComputeLogSoftmaxOutputF32KernelLasx
+
+        ld.w $t0, $a3, 0
+        ld.w $t1, $a3, 4
+        ori $t2, $zero, 0x20
+        xvreplgr2vr.w $xr4, $t0                 # broadcast negative maximum value
+        xvreplgr2vr.w $xr5, $t1                 # broadcast log(SumExp)
+        bltu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8
+
+.LComputeLogSoftmaxOutput.ProcessRemainingCountBy32:
+        xvld $xr16, $a0, 0
+        xvfadd.s $xr0, $xr4, $xr16
+        xvld $xr16, $a0, 0x20
+        xvfadd.s $xr1, $xr4, $xr16
+        addi.d $a2, $a2, -0x20
+        xvld $xr16, $a0, 0x40
+        xvfadd.s $xr2, $xr4, $xr16
+        xvld $xr16, $a0, 0x60
+        xvfadd.s $xr3, $xr4, $xr16
+        addi.d $a0, $a0, 0x80                   # advance input by 32 elements
+        xvfsub.s $xr0, $xr0, $xr5               # do as two steps for numeric stability
+        xvfsub.s $xr1, $xr1, $xr5               # do as two steps for numeric stability
+        xvfsub.s $xr2, $xr2, $xr5               # do as two steps for numeric stability
+        xvfsub.s $xr3, $xr3, $xr5               # do as two steps for numeric stability
+        xvst $xr0, $a1, 0
+        xvst $xr1, $a1, 0x20
+        xvst $xr2, $a1, 0x40
+        xvst $xr3, $a1, 0x60
+        addi.d $a1, $a1, 0x80                   # advance output by 32 elements
+        bgeu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy32
+
+.LComputeLogSoftmaxOutput.ProcessRemainingCountBy8:
+        ori $t3, $zero, 8
+        bltu $a2, $t3, .LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8
+        xvld $xr16, $a0, 0
+        xvfadd.s $xr0, $xr4, $xr16
+        addi.d $a0, $a0, 0x20
+        xvfsub.s $xr0, $xr0, $xr5
+        addi.d $a2, $a2, -8
+        xvst $xr0, $a1, 0
+        addi.d $a1, $a1, 0x20                   # advance output by 8 elements
+        b .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8
+
+.LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8:
+        beqz $a2, .LComputeLogSoftmaxOutput.ExitKernel
+
+.LComputeLogSoftmaxOutput.ProcessRemainingCountBy1:
+        fld.s $f16, $a0, 0
+        fadd.s $f0, $f4, $f16
+        addi.d $a0, $a0, 4
+        fsub.s $f0, $f0, $f5
+        fst.s $f0, $a1, 0
+        addi.d $a1, $a1, 4
+        addi.d $a2, $a2, -1
+        bnez $a2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy1
+
+.LComputeLogSoftmaxOutput.ExitKernel:
+        xvinsgr2vr.d $xr0, $zero, 2
+        xvinsgr2vr.d $xr0, $zero, 3
+        xvinsgr2vr.d $xr1, $zero, 2
+        xvinsgr2vr.d $xr1, $zero, 3
+        xvinsgr2vr.d $xr2, $zero, 2
+        xvinsgr2vr.d $xr2, $zero, 3
+        xvinsgr2vr.d $xr3, $zero, 2
+        xvinsgr2vr.d $xr3, $zero, 3
+        xvinsgr2vr.d $xr4, $zero, 2
+        xvinsgr2vr.d $xr4, $zero, 3
+        xvinsgr2vr.d $xr5, $zero, 2
+        xvinsgr2vr.d $xr5, $zero, 3
+        xvinsgr2vr.d $xr6, $zero, 2
+        xvinsgr2vr.d $xr6, $zero, 3
+        xvinsgr2vr.d $xr7, $zero, 2
+        xvinsgr2vr.d $xr7, $zero, 3
+        xvinsgr2vr.d $xr8, $zero, 2
+        xvinsgr2vr.d $xr8, $zero, 3
+        xvinsgr2vr.d $xr9, $zero, 2
+        xvinsgr2vr.d $xr9, $zero, 3
+        xvinsgr2vr.d $xr10, $zero, 2
+        xvinsgr2vr.d $xr10, $zero, 3
+        xvinsgr2vr.d $xr11, $zero, 2
+        xvinsgr2vr.d $xr11, $zero, 3
+        xvinsgr2vr.d $xr12, $zero, 2
+        xvinsgr2vr.d $xr12, $zero, 3
+        xvinsgr2vr.d $xr13, $zero, 2
+        xvinsgr2vr.d $xr13, $zero, 3
+        xvinsgr2vr.d $xr14, $zero, 2
+        xvinsgr2vr.d $xr14, $zero, 3
+        xvinsgr2vr.d $xr15, $zero, 2
+        xvinsgr2vr.d $xr15, $zero, 3
+        jr $ra
+
+        .end
diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S
new file mode 100644
index 0000000000000..96bda3bb12c6f
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S
@@ -0,0 +1,460 @@
+/*++
+
+Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    SpoolKernelLSX.S
+
+Abstract:
+
+    This module implements the kernels for the single precision pooling
+    operation.
+
+    This implementation uses LSX instructions.
+ +--*/ + +#define SP_SIZE 32*8 +#define InputBase_arg SP_SIZE+0*8 +#define InputWidth_arg SP_SIZE+1*8 +#define DilatedInputWidth_arg SP_SIZE+2*8 +#define OutputCountLeftPad_arg SP_SIZE+3*8 +#define OutputCount_arg SP_SIZE+4*8 +#define OutputCountRightPad_arg SP_SIZE+5*8 + + .macro FUNCTION_ENTRY FunctionName + + .p2align 4 + .globl \FunctionName\() + .type \FunctionName\(),@function +\FunctionName\(): + + .endm + + + .text + +/*++ + +Macro Description: + + This macro generates code to initialize registers used across the kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro InitializeKernel PoolingType + +.ifeqs "\PoolingType\()","Maximum" + li.w $s0, 0xFF7FFFFF + vreplgr2vr.w $vr5, $s0 +.endif + +.ifeqs "\PoolingType\()","AverageIncludePad" + vreplgr2vr.w $vr5, $a5 + vffint.s.w $vr5, $vr5 +.endif + + .endm +/*++ + +Macro Description: + + This macro generates the common prologue code for the pooling kernels. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro SpoolKernelEntry PoolingType + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + fst.d $f24,$sp, 6*8 + + InitializeKernel \PoolingType\() + # move InputStride to s8 + or $t8, $a4, $r0 + # move StrideWidth to a4 + or $a4, $a2, $r0 + # move DilationWidth to a5 + or $a5, $a3, $r0 + # move Output to a2 + or $a2, $a1, $r0 + + .endm + +/*++ + +Macro Description: + + This macro generates the common epilogue code for the pooling kernels. + +Arguments: + + None. + +--*/ + + .macro SpoolKernelExit + + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + fld.d $f24,$sp, 6*8 + + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + + +/*++ + +Macro Description: + + This macro generates code to clear the pooling intermediates. + + For PoolingType==Maximum, the pooling intermediates are set to the minimum + float value. Otherwise, the pooling intermediates are cleared to zero. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + vr0-vr1 - Supplies the pooling intermediates. + + vr2 - Supplies a vector containing the minimum float value broadcasted, + if PoolingType==Maximum. + +--*/ + + .macro ClearBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + vor.v $vr0, $vr5, $vr5 + vor.v $vr1, $vr5, $vr5 +.else + vxor.v $vr0, $vr0, $vr0 + vxor.v $vr1, $vr1, $vr1 +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" + xor $a1, $a1, $a1 # reset valid block counter +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to sample the input buffer and update the pooling + intermediates as appropriate. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + a4 - Supplies the StrideWidth parameter (see function description). + + vr0-vr1 - Supplies the pooling intermediates. 
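+
+    The Maximum variant updates the intermediates with vfmax.s, while the
+    two average variants accumulate a running sum with vfadd.s that
+    PostProcessBlock later divides by the valid block count or the kernel
+    size.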
+ +--*/ + + .macro ComputeBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + vld $vr24, $a3, 0 + vfmax.s $vr0, $vr0, $vr24 + vld $vr24, $a3, 16 + vfmax.s $vr1, $vr1, $vr24 +.else + vld $vr24, $a3, 0 + vfadd.s $vr0, $vr0, $vr24 + vld $vr24, $a3, 16 + vfadd.s $vr1, $vr1, $vr24 +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" + # increment valid block counter + addi.d $a1, $a1, 1 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to process and store the pooling intermediates. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a2 - Supplies the address of the output buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + vr0-vr1 - Supplies the pooling intermediates. + + vr5 - Supplies the kernel size computed by InitializeKernel, if + PoolingType=AverageExcludePad, else the actual kernel size, if + PoolingType=AverageIncludePad. + +--*/ + + .macro PostProcessBlock PoolingType, OutputCount + +// +// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding +// blocks. +// + +.ifeqs "\PoolingType\()","AverageExcludePad" + # convert valid block counter + vreplgr2vr.w $vr4, $a1 + vffint.s.w $vr4, $vr4 + vfdiv.s $vr0, $vr0, $vr4 + vfdiv.s $vr1, $vr1, $vr4 +.endif + +// +// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. +// + +.ifeqs "\PoolingType\()","AverageIncludePad" + vfdiv.s $vr0, $vr0, $vr5 + vfdiv.s $vr1, $vr1, $vr5 +.endif + +// +// Store the output block in the output buffer. +// + + vst $vr0, $a2, 0 + vst $vr1, $a2, 16 + # advance output by 1 nchw8c block + addi.d $a2, $a2, 8*4 + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute pooling for a vector of input blocks + to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a2 - Supplies the address of the output buffer. + + a4 - Supplies the StrideWidth parameter (see function description). + + a5 - Supplies the DilationWidth parameter (see function description). + + s8 - Supplies the InputStride parameter (see function description). + +--*/ + + .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount + + move $a3, $a0 + move $t1, $a6 + move $t2, $a7 +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $r0, $t3 # keep negative for lea usage below +.endif + ClearBlock \PoolingType\(), \OutputCount\() + beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing + +.L\PoolingType\().\OutputCount\().ProcessNextRow: + or $t6, $t2, $t2 + +.L\PoolingType\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + # (Input - InputBase) >= InputWidth? 
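+        # t3 holds the negated InputBase, so the add below forms
+        # (Input - InputBase). The compare is unsigned, so an input address
+        # left of InputBase wraps to a huge value and one bgeu skips the
+        # left padding region as well as the right one.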
+ add.d $t7, $a3, $t3 + bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding +.endif + ComputeBlock \PoolingType\(), \OutputCount\() + +.L\PoolingType\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $a5 # advance input by dilation width + # decrement columns remaining + addi.d $t6, $t6, -1 + bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t8 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + # advance input base to next row + sub.d $t3, $t3, $s0 +.endif + addi.d $t1, $t1, -1 + bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow + +.L\PoolingType\().\OutputCount\().HandlePostProcessing: + PostProcessBlock \PoolingType\(), \OutputCount\() + + .endm +/*++ + +Macro Description: + + This macro generates code for the inner pooling kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SpoolKernelFunction PoolingType, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute pooling for the elements of an + output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Output (a1) - Supplies the address of the output buffer. + + StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a3) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a4) - Supplies the length in bytes to advance the input buffer to + the next input row. + + ActualKernelSize (a5) - Supplies the size of the kernel based on the original + kernel dimensions, used for PoolingType=AverageIncludePad. + + KernelHeight (a6) - Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7) - Supplies the width of the kernel to apply. + + InputBase (0)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (1*8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (2*8)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (3*8)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (4*8)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (5*8)- Supplies the number of output elements that include + one or more padding elements from the right edge. + +Return Value: + + None. 
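+
+Remarks:
+
+    For reference, a sketch of the C declaration these kernels are expected
+    to match, assembled from the argument list above (the kernels are
+    declared through the MLAS_POOL_FLOAT_KERNEL typedef in mlasi.h):
+
+        void
+        MlasPoolMaximumFloatKernelLSX(
+            const float* Input,
+            float* Output,
+            size_t StrideWidth,
+            size_t DilationWidth,
+            size_t InputStride,
+            size_t ActualKernelSize,
+            size_t KernelHeight,
+            size_t KernelWidth,
+            const float* InputBase,
+            size_t InputWidth,
+            size_t DilatedInputWidth,
+            size_t OutputCountLeftPad,
+            size_t OutputCount,
+            size_t OutputCountRightPad
+            );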
+ +--*/ + + FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() + SpoolKernelEntry \PoolingType\() + + ld.d $s0, $sp, OutputCountLeftPad_arg + ld.d $s1, $sp, OutputCount_arg + add.d $t0, $s0, $s1 + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\PoolingType\().ExitKernel + +.L\PoolingType\().ProcessNextOutputCount: + ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 1 + add.d $a0, $a0, $a4 + addi.d $t0, $t0, -1 + bnez $t0, .L\PoolingType\().ProcessNextOutputCount + +.L\PoolingType\().ExitKernel: + SpoolKernelExit + + .endm + +// +// Generate the pooling kernels. +// + + SpoolKernelFunction Maximum, LSX + SpoolKernelFunction AverageExcludePad, LSX + SpoolKernelFunction AverageIncludePad, LSX + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S new file mode 100644 index 0000000000000..6e5f0136cd4ab --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S @@ -0,0 +1,238 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision pooling + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "SpoolKernelLasxCommon.h" + + .text + +/*++ + +Macro Description: + + This macro generates code to initialize registers used across the kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + +Implicit Arguments: + + a5 - Supplies the ActualKernelSize parameter (see function description). + +--*/ + + .macro InitializeKernel PoolingType + +.ifeqs "\PoolingType\()","Maximum" + li.w $s0, 0xFF7FFFFF + xvreplgr2vr.w $xr5, $s0 +.else + xvxor.v $xr5, $xr5, $xr5 +.ifeqs "\PoolingType\()","AverageExcludePad" + move $t6, $a6 + mul.d $t6, $t6, $a7 + xvreplgr2vr.w $xr5, $t6 +.else + xvreplgr2vr.w $xr5, $a5 +.endif + xvffint.s.w $xr5, $xr5 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to clear the pooling intermediates. + + For PoolingType==Maximum, the pooling intermediates are set to the minimum + float value. Otherwise, the pooling intermediates are cleared to zero. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + xr0-xr2 - Supplies the pooling intermediates. + + xr5 - Supplies a vector containing the minimum float value broadcasted, + if PoolingType==Maximum. + +--*/ + + .macro ClearBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + EmitIfCountGE \OutputCount\(), 1, "xvor.v $xr0, $xr5, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvor.v $xr1, $xr5, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvor.v $xr2, $xr5, $xr5" +.else + EmitIfCountGE \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" + EmitIfCountGE \OutputCount\(), 2, "xvxor.v $xr1, $xr1, $xr1" + EmitIfCountGE \OutputCount\(), 3, "xvxor.v $xr2, $xr2, $xr2" +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + xor $a1, $a1, $a1 # reset valid block counter +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to sample the input buffer and update the pooling + intermediates as appropriate. 
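+
+    For OutputCount values of 2 and 3, the neighboring input blocks are
+    loaded with xvldx using StrideWidth (and twice StrideWidth) as the
+    register offset, so up to three output positions are accumulated in
+    parallel.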
+ +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + a4 - Supplies the StrideWidth parameter (see function description). + + xr0-xr2 - Supplies the pooling intermediates. + +--*/ + + .macro ComputeBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfmax.s $xr0, $xr0, $xr16" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" + EmitIfCountGE \OutputCount\(), 2, "xvfmax.s $xr1, $xr1, $xr16" + EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" + EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" + EmitIfCountGE \OutputCount\(), 3, "xvfmax.s $xr2, $xr2, $xr16" +.else + EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" + EmitIfCountGE \OutputCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" + EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" + EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" + EmitIfCountGE \OutputCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + addi.d $a1, $a1, 1 # increment valid block counter +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to process and store the pooling intermediates. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a2 - Supplies the address of the output buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + xr0-xr2 - Supplies the pooling intermediates. + + xr5 - Supplies the kernel size computed by InitializeKernel, if + PoolingType=AverageExcludePad, else the actual kernel size, if + PoolingType=AverageIncludePad. + +--*/ + + .macro PostProcessBlock PoolingType, OutputCount + +// +// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding +// blocks. OutputCount=1 generates code to count the number of blocks accessed by +// ComputeBlock. Other cases use the kernel size computed by InitializeKernel. +// + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + xvxor.v $xr4, $xr4, $xr4 + xvreplgr2vr.w $xr4, $a1 + xvffint.s.w $xr4, $xr4 + xvfdiv.s $xr0, $xr0, $xr4 +.else + EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" +.endif +.endif + +// +// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. +// + +.ifeqs "\PoolingType\()","AverageIncludePad" + EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" +.endif + +// +// Store the output block in the output buffer. 
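+// Each nchw8c block is eight floats (32 bytes), so up to three 256-bit xvst
+// stores are emitted and the output pointer then advances by OutputCount
+// blocks.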
+// + + EmitIfCountGE \OutputCount\(), 1, "xvst $xr0, $a2, 0" + EmitIfCountGE \OutputCount\(), 2, "xvst $xr1, $a2, 0x20" + EmitIfCountGE \OutputCount\(), 3, "xvst $xr2, $a2, 0x40" + add_immed $a2,\OutputCount\()*8*4 # advance output by N nchw8c blocks + + .endm + +// +// Generate the pooling kernels. +// + + SpoolKernelFunction Maximum, Lasx + SpoolKernelFunction AverageExcludePad, Lasx + SpoolKernelFunction AverageIncludePad, Lasx + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h new file mode 100644 index 0000000000000..066c75d34f3f9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h @@ -0,0 +1,311 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelasxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision pooling operation for the Lasx kernels. + +--*/ + +// +// Stack frame layout for the pooling kernels. +// + +#define SP_SIZE 8*8 +#define InputBase_arg SP_SIZE+0*8 +#define InputWidth_arg SP_SIZE+1*8 +#define DilatedInputWidth_arg SP_SIZE+2*8 +#define OutputCountLeftPad_arg SP_SIZE+3*8 +#define OutputCount_arg SP_SIZE+4*8 +#define OutputCountRightPad_arg SP_SIZE+5*8 +/*++ + +Macro Description: + + This macro generates the common prologue code for the pooling kernels. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro SpoolKernelEntry PoolingType + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 1*8 + fst.d $f16, $sp, 2*8 + st.d $ra, $sp, 5*8 + + InitializeKernel \PoolingType\() + move $t8, $a4 + move $a4, $a2 + move $a5, $a3 + move $a2, $a1 + + .endm + +/*++ + +Macro Description: + + This macro generates the common epilogue code for the pooling kernels. + +Arguments: + + None. + +--*/ + + .macro SpoolKernelExit + + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 1*8 + fld.d $f16, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute pooling for a vector of input blocks + to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a2 - Supplies the address of the output buffer. + + a4 - Supplies the StrideWidth parameter (see function description). + + a5 - Supplies the DilationWidth parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + +--*/ + + .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount + + move $a3, $a0 + move $t1, $a6 + move $t2, $a7 +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $zero, $t3 +.endif + ClearBlock \PoolingType\(), \OutputCount\() + beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing + +.L\PoolingType\().\OutputCount\().ProcessNextRow: + move $t6, $t2 + +.L\PoolingType\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 # compute (Input - InputBase) + # (Input - InputBase) >= InputWidth? 
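+        # As in the LSX kernel above, t3 is the negated InputBase and the
+        # compare is unsigned, so this one branch covers both the left and
+        # right padding regions.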
+ bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding +.endif + ComputeBlock \PoolingType\(), \OutputCount\() + +.L\PoolingType\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $a5 # advance input by dilation width + addi.d $t6, $t6, -1 # decrement columns remaining + bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t8 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + sub.d $t3, $t3, $s0 + # advance input base to next row +.endif + addi.d $t1, $t1, -1 + bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow + +.L\PoolingType\().\OutputCount\().HandlePostProcessing: + PostProcessBlock \PoolingType\(), \OutputCount\() + + .endm +/*++ + +Macro Description: + + This macro generates code for the inner pooling kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SpoolKernelFunction PoolingType, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute pooling for the elements of an + output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Output (a1) - Supplies the address of the output buffer. + + StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a3) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a4) - Supplies the length in bytes to advance the input buffer to + the next input row. + + ActualKernelSize (a5) - Supplies the size of the kernel based on the original + kernel dimensions, used for PoolingType=AverageIncludePad. + + KernelHeight (a6) - Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7)- Supplies the width of the kernel to apply. + + InputBase (sp + 0)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp + 0x8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp + 0x10)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp + 0x18)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp + 0x20)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp + 0x28)- Supplies the number of output elements that include + one or more padding elements from the right edge. + +Return Value: + + None. 
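+
+Remarks:
+
+    platform.cpp installs these kernels into the PoolFloatKernel table of
+    MLAS_PLATFORM (indexed by MlasMaximumPooling and the two average pooling
+    kinds) when getauxval(AT_HWCAP) reports HWCAP_LOONGARCH_LASX.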
+ +--*/ + + FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() + + SpoolKernelEntry \PoolingType\() + +.L\PoolingType\().ProcessOutputCountLeftPad: + ld.d $t0, $sp, OutputCountLeftPad_arg + + beqz $t0, .L\PoolingType\().ProcessOutputCount + bl MlasPool\PoolingType\()FloatSingle\Isa\() + +.L\PoolingType\().ProcessOutputCount: + ld.d $t0, $sp, OutputCount_arg + li.d $s0, 3 + bltu $t0, $s0, .L\PoolingType\().ProcessRemainingOutputCount + +.L\PoolingType\().ProcessNextOutputCountBy3: + ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 3 + slli.d $s0, $a4, 1 + add.d $t6, $s0, $a4 + add.d $a0, $a0, $t6 # advance input by 3 elements + addi.d $t0, $t0, -3 + li.d $s0, 3 + bgeu $t0, $s0, .L\PoolingType\().ProcessNextOutputCountBy3 + +.L\PoolingType\().ProcessRemainingOutputCount: + +.L\PoolingType\().ProcessOutputCountRightPad: + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\PoolingType\().ExitKernel + bl MlasPool\PoolingType\()FloatSingle\Isa\() + +.L\PoolingType\().ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + SpoolKernelExit + +// +// Generate out-of-band helpers for handling output blocks involving padding. +// + +MlasPool\PoolingType\()FloatSingle\Isa\(): + st.d $ra, $sp, 6*8 +loopMlasPool\PoolingType\()FloatSingle\Isa\(): + ProcessOutputCountN .LSpoolKernelSingleFrame, \PoolingType\(), 1 + add.d $a0, $a0, $a4 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + bnez $t0, loopMlasPool\PoolingType\()FloatSingle\Isa\() + ld.d $ra, $sp, 6*8 + jr $ra + + .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h b/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h new file mode 100644 index 0000000000000..837aca77dd883 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h @@ -0,0 +1,144 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + asmmacro.h + +Abstract: + + This module implements common macros for the assembly modules. + +--*/ + +#define C_UNDERSCORE(symbol) symbol + +.macro vmove dst src + vand.v \dst, \src, \src +.endm + +/*++ + +Macro Description: + + This macro emits the assembler directives to annotate a new function. + +Arguments: + + FunctionName - Supplies the name of the function. 
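+
+    For example, "FUNCTION_ENTRY MlasGemmFloatKernelLasx" expands to the
+    .globl and .type directives plus the label that the corresponding
+    MLAS_GEMM_FLOAT_KERNEL declaration in mlasi.h binds against.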
+ +--*/ + + .macro FUNCTION_ENTRY FunctionName + .align 2 + .globl \FunctionName\() + .type \FunctionName\(),@function +\FunctionName\(): + + .endm + +/*++ + +Macro Description: + + This macro generates an optimization for "add reg,128" which can instead + be encoded as "sub reg,-128" to reduce code size by using a signed 8-bit + value. + +Arguments: + + Register - Supplies the register to be added to. + + Immediate - Supplies the immediate to add to the register. + +--*/ + + .macro add_immed Register, Immediate + +.if (\Immediate\() != 128) + addi.d \Register\(),\Register\(),\Immediate\() +.else + addi.d \Register\(),\Register\(),\Immediate\() # smaller encoding +.endif + + .endm + +/*++ + +Macro Description: + + This macro conditionally emits the statement if Count is greater than or + equal to Value. + +Arguments: + + Count - Supplies the variable used in the comparison. + + Value - Supplies the static used in the comparison. + + Statement - Supplies the statement to conditionally emit. + +--*/ + + .macro EmitIfCountGE Count1, Value1, Statement + +.if (\Count1\() >= \Value1\()) + \Statement\() +.endif + + .endm + +/*++ + +Macro Description: + + This macro conditionally emits the statement if Count1 is greater than or + equal to Value1 and Count2 is greater than or equal to Value2. + +Arguments: + + Count1 - Supplies the variable used in the comparison. + + Value1 - Supplies the static used in the comparison. + + Count2 - Supplies the variable used in the comparison. + + Value2 - Supplies the static used in the comparison. + + Statement - Supplies the statement to conditionally emit. + +--*/ + + .macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement + +.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\()) + \Statement\() +.endif + + .endm + +/*++ + +Macro Description: + + This macro emits the statement for each register listed in the register + list. The statement can use RegItem to access the current register. + +Arguments: + + RegList - Supplies the list of registers. + + Statement - Supplies the statement to emit. + +--*/ + + .macro EmitForEachRegister RegList, Statement + + .irp RegItem, \RegList\() + \Statement\() + .endr + + .endm diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6c859e4e4f44b..7bb8b17031a84 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -50,7 +50,9 @@ Module Name: #include #endif #if defined(__x86_64__) || defined(__i386__) +#if !defined(signature_VORTEX_ebx) && !defined(signature_NEXGEN_ebx) && !defined(signature_AMD_ebx)//workaround for Bug 96238 - [i386] cpuid.h header needs include guards #include +#endif #if defined(__GNUC__) && __GNUC__ >= 12 #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h. @@ -67,6 +69,9 @@ Module Name: #undef pixel #undef bool #endif +#if defined(__loongarch64) +#include +#endif #if defined(MLAS_TARGET_WASM_SIMD) #include #endif @@ -317,7 +322,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // Define the prototypes of the platform optimized routines. 
// -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || \ + defined(MLAS_TARGET_LARCH64) typedef size_t @@ -694,6 +700,30 @@ extern "C" { MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelVSX; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelVSX; +#elif defined(MLAS_TARGET_LARCH64) + MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLSX; + MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLasx; + MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLSX; + MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLasx; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLSX; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLSX; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLSX; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLSX; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLasx; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLasx; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLasx; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLasx; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4LSX; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4Lasx; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelLasx; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelLasx; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelLasx; #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; @@ -854,6 +884,7 @@ MlasSgemmOperation( struct MLAS_GEMM_QUANT_DISPATCH; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; @@ -979,7 +1010,22 @@ struct MLAS_PLATFORM { #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif - +#if defined(MLAS_TARGET_LARCH64) + const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; + MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; + MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; + MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; + uint32_t NchwcBlockSize; +#endif #if defined(MLAS_TARGET_AMD64_IX86) const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; @@ -1256,6 +1302,8 @@ MlasConvDepthwiseFloat_CHW( 
#endif #elif defined(MLAS_TARGET_WASM_SIMD) #define MLAS_WASM_SIMD_INTRINSICS +#elif defined(MLAS_TARGET_LARCH64) +#define MLAS_LSX_INTRINSICS #endif #if defined(MLAS_NEON_INTRINSICS) @@ -1271,6 +1319,9 @@ typedef __vector unsigned MLAS_UINT32X4; #elif defined(MLAS_WASM_SIMD_INTRINSICS) typedef v128_t MLAS_FLOAT32X4; typedef v128_t MLAS_INT32X4; +#elif defined(MLAS_LSX_INTRINSICS) +typedef __m128 MLAS_FLOAT32X4; +typedef __m128i MLAS_INT32X4; #else typedef float MLAS_FLOAT32X4 __attribute__ ((vector_size(16))); typedef int32_t MLAS_INT32X4 __attribute__ ((vector_size(16))); @@ -1284,6 +1335,8 @@ MlasReinterpretAsInt32x4(MLAS_FLOAT32X4 Vector) return vreinterpretq_s32_f32(Vector); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_castps_si128(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_INT32X4)Vector; #else return MLAS_INT32X4(Vector); #endif @@ -1299,6 +1352,8 @@ MlasCastToInt32x4(MLAS_FLOAT32X4 Vector) return _mm_cvttps_epi32(Vector); #elif defined(MLAS_VSX_INTRINSICS) return vec_cts(Vector, 0); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vftint_w_s(Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return (MLAS_INT32X4)__builtin_convertvector((__f32x4)Vector, __i32x4); #else @@ -1318,6 +1373,8 @@ MlasCastToFloat32x4(MLAS_INT32X4 Vector) return vec_ctf(Vector, 0); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_convert_i32x4(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vffint_s_w(Vector); #else return MLAS_FLOAT32X4{float(Vector[0]), float(Vector[1]), float(Vector[2]), float(Vector[3])}; #endif @@ -1335,6 +1392,8 @@ MlasBroadcastInt32x4(int32_t Value) return wasm_i32x4_splat(Value); #elif defined(MLAS_VSX_INTRINSICS) return vec_splats(Value); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vreplgr2vr_w(Value); #else return MLAS_INT32X4{Value, Value, Value, Value}; #endif @@ -1352,6 +1411,8 @@ MlasLoadInt32x4(const int32_t* Buffer) return vec_vsx_ld(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vld((const MLAS_INT32X4*)Buffer, 0); #else return *((MLAS_INT32X4*)Buffer); #endif @@ -1369,6 +1430,8 @@ MlasStoreInt32x4(int32_t* Buffer, MLAS_INT32X4 Vector) vec_vsx_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + __lsx_vst(Vector, (MLAS_INT32X4 *)Buffer, 0); #else *((MLAS_INT32X4*)Buffer) = Vector; #endif @@ -1386,6 +1449,8 @@ MlasAddInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return wasm_i32x4_add(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_add(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vadd_w(Vector1, Vector2); #else return Vector1 + Vector2; #endif @@ -1401,6 +1466,8 @@ MlasSubtractInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_sub_epi32(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_sub(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vsub_w(Vector1, Vector2); #else return Vector1 - Vector2; #endif @@ -1416,6 +1483,8 @@ MlasAndInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_and_si128(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_and(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vand_v(Vector1, Vector2); #else return Vector1 & Vector2; #endif @@ -1431,6 +1500,8 @@ MlasOrInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_or_si128(Vector1, Vector2); #elif 
defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_or(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vor_v(Vector1, Vector2); #else return Vector1 | Vector2; #endif @@ -1446,6 +1517,8 @@ MlasAndNotInt32x4(MLAS_INT32X4 VectorNot, MLAS_INT32X4 Vector) return _mm_andnot_si128(VectorNot, Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_andnot(Vector, VectorNot); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vandn_v(VectorNot, Vector); #else return (~VectorNot) & Vector; #endif @@ -1463,6 +1536,8 @@ MlasXorInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return wasm_v128_xor(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_xor(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vxor_v(Vector1, Vector2); #else return Vector1 ^ Vector2; #endif @@ -1486,6 +1561,8 @@ MlasShiftLeftInt32x4(MLAS_INT32X4 Vector) return _mm_slli_epi32(Vector, ShiftCount); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_shl(Vector, ShiftCount); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vslli_w(Vector, ShiftCount); #else return Vector << ShiftCount; #endif @@ -1505,6 +1582,8 @@ MlasMaximumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return vec_vmaxsw(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_max(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vmax_w(Vector1, Vector2); #else return MlasBlendInt32x4(Vector2, Vector1, Vector1 > Vector2); #endif @@ -1524,6 +1603,8 @@ MlasMinimumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return vec_vminsw(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_min(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vmin_w(Vector1, Vector2); #else return MlasBlendInt32x4(Vector2, Vector1, Vector2 > Vector1); #endif @@ -1537,6 +1618,8 @@ MlasReinterpretAsFloat32x4(MLAS_INT32X4 Vector) return vreinterpretq_f32_s32(Vector); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_castsi128_ps(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4(Vector); #else return MLAS_FLOAT32X4(Vector); #endif @@ -1556,6 +1639,8 @@ MlasBroadcastFloat32x4(float Value) // Suppress wrong GCC warnings MLAS_UNREFERENCED_PARAMETER(Value); return vec_splats(Value); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4{Value, Value, Value, Value}; #else return MLAS_FLOAT32X4{Value, Value, Value, Value}; #endif @@ -1573,6 +1658,8 @@ MlasBroadcastFloat32x4(const float* Value) return wasm_v128_load32_splat(Value); #elif defined(MLAS_VSX_INTRINSICS) return vec_splats(*Value); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; #else return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; #endif @@ -1588,6 +1675,8 @@ MlasZeroFloat32x4(void) return _mm_setzero_ps(); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_const(0.0f, 0.0f, 0.0f, 0.0f); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBroadcastFloat32x4(0.0f); #else return MlasBroadcastFloat32x4(0.0f); #endif @@ -1605,6 +1694,9 @@ MlasLoadFloat32x4(const float* Buffer) return vec_vsx_ld(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + // return MlasReinterpretAsFloat32x4(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); + return (MLAS_FLOAT32X4)__lsx_vld((const MLAS_INT32X4 *)Buffer, 0); #else return *((MLAS_FLOAT32X4*)Buffer); #endif @@ -1622,6 +1714,8 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) 
vec_vsx_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + __lsx_vst(MlasReinterpretAsInt32x4(Vector), Buffer, 0); #else *((MLAS_FLOAT32X4*)Buffer) = Vector; #endif @@ -1642,6 +1736,8 @@ MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) vec_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + MlasStoreFloat32x4(Buffer, Vector); #else MlasStoreFloat32x4(Buffer, Vector); #endif @@ -1660,6 +1756,8 @@ MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_store_ss(Buffer, _mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); #elif defined(MLAS_WASM_SIMD_INTRINSICS) *Buffer = ((__f32x4)(Vector))[Lane]; +#elif defined(MLAS_LSX_INTRINSICS) + *Buffer = Vector[Lane]; #else *Buffer = Vector[Lane]; #endif @@ -1675,6 +1773,9 @@ MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_storel_pi((__m64*)Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) *((long long*)Buffer) = ((__vector long long)Vector)[0]; +#elif defined(MLAS_LSX_INTRINSICS) + MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); + MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); #else MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); @@ -1692,6 +1793,8 @@ MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector) return _mm_cvtss_f32(_mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_extract_lane(Vector, Lane); +#elif defined(MLAS_LSX_INTRINSICS) + return Vector[Lane]; #else return Vector[Lane]; #endif @@ -1736,6 +1839,9 @@ MlasShuffleFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_i32x4_shuffle(Vector1, Vector2, Index0, Index1, Index2, Index3); #elif defined(__clang__) return __builtin_shufflevector(Vector1, Vector2, Index0, Index1, Index2, Index3); +#elif defined(MLAS_LSX_INTRINSICS) + typedef int32_t GEN_INT32X4 __attribute__ ((vector_size(16))); + return __builtin_shuffle(Vector1, Vector2, GEN_INT32X4{Index0, Index1, Index2, Index3}); #else return __builtin_shuffle(Vector1, Vector2, MLAS_INT32X4{Index0, Index1, Index2, Index3}); #endif @@ -1764,6 +1870,8 @@ MlasInterleaveLowFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_unpacklo_ps(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_mergeh(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vilvl_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); #else return MlasShuffleFloat32x4<0, 4, 1, 5>(Vector1, Vector2); #endif @@ -1782,6 +1890,8 @@ MlasInterleaveHighFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_unpackhi_ps(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_mergel(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vilvh_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); #else return MlasShuffleFloat32x4<2, 6, 3, 7>(Vector1, Vector2); #endif @@ -1799,6 +1909,8 @@ MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_f32x4_add(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_add(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfadd_s(Vector1, Vector2); #else return Vector1 + Vector2; #endif @@ -1816,6 +1928,8 @@ MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return 
wasm_f32x4_sub(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_sub(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfsub_s(Vector1, Vector2); #else return Vector1 - Vector2; #endif @@ -1836,6 +1950,8 @@ MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) MLAS_UNREFERENCED_PARAMETER(Vector1); MLAS_UNREFERENCED_PARAMETER(Vector2); return vec_mul(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmul_s(Vector1, Vector2); #else return Vector1 * Vector2; #endif @@ -1855,6 +1971,8 @@ MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FL return vec_madd(Vector1, Vector2, Vector3); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_add(wasm_f32x4_mul(Vector1, Vector2), Vector3); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmadd_s(Vector1, Vector2, Vector3); #else return Vector1 * Vector2 + Vector3; #endif @@ -1890,6 +2008,8 @@ MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_div_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_div(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfdiv_s(Vector1, Vector2); #else return Vector1 / Vector2; #endif @@ -1907,6 +2027,8 @@ MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_f32x4_gt(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2)); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vfcmp_clt_s(Vector2, Vector1); #else return Vector1 > Vector2; #endif @@ -1920,6 +2042,8 @@ MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_and_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_and(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1933,6 +2057,8 @@ MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_or_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_or(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1946,6 +2072,8 @@ MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector) return _mm_andnot_ps(VectorNot, Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_andnot(Vector, VectorNot); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); #else return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); #endif @@ -1959,6 +2087,8 @@ MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_xor_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_xor(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return 
MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1984,6 +2114,8 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_max(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmax_s(Vector1, Vector2); #else return MlasBlendFloat32x4(Vector2, Vector1, Vector1 > Vector2); #endif @@ -2002,6 +2134,8 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_min(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmin_s(Vector1, Vector2); #else return MlasBlendFloat32x4(Vector2, Vector1, Vector2 > Vector1); #endif @@ -2108,6 +2242,8 @@ MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) typedef __m128d MLAS_FLOAT64X2; #elif defined(MLAS_VSX_INTRINSICS) typedef __vector double MLAS_FLOAT64X2; +#elif defined(MLAS_LSX_INTRINSICS) +typedef __m128d MLAS_FLOAT64X2; #else #define MLAS_FLOAT64X2_UNSUPPORTED #endif @@ -2129,6 +2265,27 @@ MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FL return vec_madd(Vector1, Vector2, Vector3); } +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasBroadcastFloat64x2(const double *Value) +{ + return MLAS_FLOAT64X2{*Value, *Value}; +} +#elif defined(MLAS_LSX_INTRINSICS) +template +MLAS_FORCEINLINE +double +MlasExtractLaneFloat64x2(MLAS_FLOAT64X2 Vector) +{ + return Vector[Lane]; +} +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FLOAT64X2 Vector3) +{ + return __lsx_vfmadd_d(Vector1, Vector2, Vector3); +} + MLAS_FORCEINLINE MLAS_FLOAT64X2 MlasBroadcastFloat64x2(const double *Value) @@ -2144,6 +2301,8 @@ MlasBroadcastFloat64x2(double Value) return _mm_set1_pd(Value); #elif defined(MLAS_VSX_INTRINSICS) return MLAS_FLOAT64X2{Value, Value}; +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT64X2{Value, Value}; #endif } @@ -2155,6 +2314,8 @@ MlasZeroFloat64x2(void) return _mm_setzero_pd(); #elif defined(MLAS_VSX_INTRINSICS) return MlasBroadcastFloat64x2(0.0f); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBroadcastFloat64x2(0.0f); #endif } @@ -2166,6 +2327,8 @@ MlasLoadFloat64x2(const double* Buffer) return _mm_loadu_pd(Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT64X2(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); #endif } @@ -2177,6 +2340,8 @@ MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_storeu_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif } @@ -2188,6 +2353,8 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_store_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) *((MLAS_FLOAT64X2*)Buffer) = Vector; +#elif defined(MLAS_LSX_INTRINSICS) + (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif } @@ -2199,6 +2366,8 @@ MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) return _mm_mul_pd(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return Vector1 * Vector2; +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmul_d(Vector1, Vector2); #endif } @@ -2233,6 +2402,17 @@ MlasReadTimeStampCounter(void) ); return ((uint64_t)edx << 32) | eax; 
+#elif defined(MLAS_TARGET_LARCH64) + uint64_t time_cnt, id; + + __asm__ __volatile__ + ( + "rdtime.d %0, %1\n\t" + : "=r" (time_cnt), "=r" (id) + :: + ); + + return time_cnt; #else return 0; #endif diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index fec56c6ee063f..8329a34f1338f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -185,6 +185,28 @@ MlasInitAMX() #endif // MLAS_TARGET_AMD64_IX86 +#ifdef MLAS_TARGET_LARCH64 + +#if defined(__linux__) +#include +#include +#endif +// +// Stores a vector to build a conditional load/store mask for vmaskmovps. +// + +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveLasx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 }; + +// +// Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks. +// + +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableLasx[16], 32) = { + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, +}; + +#endif MLAS_PLATFORM::MLAS_PLATFORM( void ) @@ -536,6 +558,63 @@ Return Value: #endif // __linux__ #endif // MLAS_TARGET_POWER +#if defined(MLAS_TARGET_LARCH64) + + // + // Default to the baseline LSX support. + // + + int hwcap = getauxval(AT_HWCAP); + bool cap_lasx = hwcap & HWCAP_LOONGARCH_LASX; + bool cap_lsx = hwcap & HWCAP_LOONGARCH_LSX; + + if( cap_lasx ){ + this->GemmFloatKernel = MlasGemmFloatKernelLasx; + this->GemmDoubleKernel = MlasGemmDoubleKernelLasx; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLasx; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLasx; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLasx; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLasx; + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLasx; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLasx; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLasx; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelLasx; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelLasx; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelLasx; + this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Lasx; + + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; + }else if( cap_lsx ){ + this->GemmFloatKernel = MlasGemmFloatKernelLSX; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; + this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4LSX; + this->GemmDoubleKernel = MlasGemmDoubleKernelLSX; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLSX; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLSX; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLSX; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLSX; + + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLSX; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLSX; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLSX; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = 
MlasComputeLogSoftmaxOutputF32Kernel; + }else{ + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; + } + + this->NchwcBlockSize = 8; + // this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; + + // this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; + +#endif // MLAS_TARGET_LARCH64 + } size_t diff --git a/onnxruntime/core/mlas/lib/pooling.cpp b/onnxruntime/core/mlas/lib/pooling.cpp index 12128f6c700fd..50dcf19224510 100644 --- a/onnxruntime/core/mlas/lib/pooling.cpp +++ b/onnxruntime/core/mlas/lib/pooling.cpp @@ -1569,6 +1569,96 @@ Return Value: c -= 16; } +#elif defined(MLAS_LSX_INTRINSICS) + uint32_t val = 0x80808080; + const __m128i BitFlipVector = __lsx_vreplgr2vr_w(val); + if constexpr (std::is_unsigned::value) { + MLAS_UNREFERENCED_PARAMETER(BitFlipVector); + } + + while (c >= 32) { + + __m128i MaximumVector0 = __lsx_vldi(0); + __m128i MaximumVector1 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + __m128i InputVector1 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset + 16], 0); + + if constexpr (std::is_signed::value) { + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + InputVector1 = __lsx_vxor_v(InputVector1, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + MaximumVector1 = __lsx_vmax_bu(MaximumVector1, InputVector1); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + MaximumVector1 = __lsx_vxor_v(MaximumVector1, BitFlipVector); + } + + __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); + __lsx_vst(MaximumVector1, (__m128i*)&Output[16], 0); + Output += 32; + + ChannelOffset += 32; + c -= 32; + } + + while (c >= 16) { + + __m128i MaximumVector0 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + + if constexpr (std::is_signed::value){ + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + } + + __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); + Output += 16; + + ChannelOffset += 16; + c -= 16; + } + + if (c >= 8) { + + __m128i MaximumVector0 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0), 0, 1); + + if constexpr (std::is_signed::value){ + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + } + + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i*)&Output[0] , 0), __lsx_vpickve2gr_d(MaximumVector0, 0), 0), (__m128i*)&Output[0], 0); + Output += 8; + + ChannelOffset += 8; + c -= 8; + } #endif while (c > 0) { diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 830a3a6a492db..1fed8af21b31c 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -86,11 +86,11 @@ Return Value: if 
     if constexpr (std::is_same_v<OutputType, uint8_t> || std::is_same_v<OutputType, int8_t>) {
       auto CharVector = vec_pack(ShortVector0, ShortVector1);
-      vec_xst(CharVector, 0, Output);
+      vec_xst(CharVector, 0, (int8_t *)Output);
     } else {
       static_assert(std::is_same_v<OutputType, uint16_t> || std::is_same_v<OutputType, int16_t>);
-      vec_xst(ShortVector0, 0, Output);
-      vec_xst(ShortVector1, 0, &Output[8]);
+      vec_xst(ShortVector0, 0, (int16_t *)Output);
+      vec_xst(ShortVector1, 0, (int16_t *)&Output[8]);
     }

     Output += 16;
diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp
index 48d975a7fd26d..b5784ecb56d01 100644
--- a/onnxruntime/core/mlas/lib/q4_dq.cpp
+++ b/onnxruntime/core/mlas/lib/q4_dq.cpp
@@ -779,6 +779,17 @@ MlasBlockwiseQuantMetaShape(
     int& meta_cols
     );

+template
+void
+MlasBlockwiseQuantMetaShape(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& meta_rows,
+    int& meta_cols
+    );
+
 template
 void
 MlasBlockwiseQuantizedShape(
@@ -790,6 +801,16 @@ MlasBlockwiseQuantizedShape(
     int& q_cols
     );

+template
+void
+MlasBlockwiseQuantizedShape(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& q_rows,
+    int& q_cols
+    );

 void MLASCALL
 MlasBlockwiseQuantizedBufferSizes(
diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h
index b1b51dd53c4fc..d16798eb8945f 100644
--- a/onnxruntime/core/mlas/lib/q4gemm.h
+++ b/onnxruntime/core/mlas/lib/q4gemm.h
@@ -126,7 +126,7 @@ MlasQ4GemmOperation(
         size_t RowsRemaining = RangeCountM;
         while (RowsRemaining > 0) {

-#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER)
+#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
             auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
                 a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true);
 #else
diff --git a/onnxruntime/core/mlas/lib/qdwconv.cpp b/onnxruntime/core/mlas/lib/qdwconv.cpp
index 924009ab5ccf4..59f6877f70d56 100644
--- a/onnxruntime/core/mlas/lib/qdwconv.cpp
+++ b/onnxruntime/core/mlas/lib/qdwconv.cpp
@@ -41,6 +41,10 @@ MlasConvDepthwiseKernel(
 #elif defined(MLAS_NEON_INTRINSICS)
     const uint8x8_t InputZeroPointVector = vdup_n_u8(uint8_t(InputZeroPoint));
     const uint8x8_t FilterZeroPointVector = vdup_n_u8(uint8_t(FilterZeroPoint));
+#elif defined(MLAS_LSX_INTRINSICS)
+    const __m128i ZeroVector = __lsx_vldi(0);
+    const __m128i InputZeroPointVector = __lsx_vreplgr2vr_h(InputZeroPoint);
+    const __m128i FilterZeroPointVector = __lsx_vreplgr2vr_h(FilterZeroPoint);
 #endif

     while (OutputCount > 0) {
@@ -141,6 +145,54 @@ MlasConvDepthwiseKernel(
             vst1q_s32(&Output[4], Accumulator1);
             Output += 8;

+            ChannelOffset += 8;
+            c -= 8;
+        }
+#elif defined(MLAS_LSX_INTRINSICS)
+
+        while (c >= 8) {
+            __m128i Accumulator0 = __lsx_vldi(0);
+            __m128i Accumulator1 = __lsx_vldi(0);
+            size_t ChannelKernelOffset = ChannelOffset;
+
+            for (size_t k = 0; k < KernelSize; k++) {
+                __m128i InputVector = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0);
+                InputVector = __lsx_vinsgr2vr_d(InputVector, 0, 1);
+                __m128i FilterVector =
+                    __lsx_vld((const __m128i*)&Filter[ChannelKernelOffset], 0);
+                FilterVector = __lsx_vinsgr2vr_d(FilterVector, 0, 1);
+
+                if (std::is_signed<InputType>::value) {
+                    InputVector = __lsx_vsrai_h(__lsx_vilvl_b(InputVector, ZeroVector), 8);
+                } else {
+                    InputVector = __lsx_vilvl_b(ZeroVector, InputVector);
+                }
+
+                if (std::is_signed<FilterType>::value) {
+                    FilterVector = __lsx_vsrai_h(__lsx_vilvl_b(FilterVector, ZeroVector), 8);
+                } else {
+                    FilterVector = __lsx_vilvl_b(ZeroVector, FilterVector);
+                }
+
+                InputVector = __lsx_vsub_h(InputVector, InputZeroPointVector);
+                FilterVector =
__lsx_vsub_h(FilterVector, FilterZeroPointVector); + + // N.B. Emulate PMULLD functionality on LSX by computing the low + // and high parts of the result and interleaving the results. + __m128i MultiplyLowWords = __lsx_vmul_h(InputVector, FilterVector); + __m128i MultiplyHighWords = __lsx_vmuh_h(InputVector, FilterVector); + __m128i Multiply0 = __lsx_vilvl_h(MultiplyHighWords, MultiplyLowWords); + __m128i Multiply1 = __lsx_vilvh_h(MultiplyHighWords, MultiplyLowWords); + + Accumulator0 = __lsx_vadd_w(Accumulator0, Multiply0); + Accumulator1 = __lsx_vadd_w(Accumulator1, Multiply1); + ChannelKernelOffset += Channels; + } + + __lsx_vst(Accumulator0, (__m128i*)&Output[0], 0); + __lsx_vst(Accumulator1, (__m128i*)&Output[4], 0); + Output += 8; + ChannelOffset += 8; c -= 8; } @@ -322,4 +374,4 @@ Return Value: ); } } -} \ No newline at end of file +} diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index 1fcd44e78a28c..75c17a6b5a177 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -871,7 +871,7 @@ MlasGemmQuantGetDispatch( GemmQuantDispatch = &MlasGemmQuantDispatchDefault; } -#if defined(MLAS_TARGET_AMD64_IX86) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) if (!AIsSigned) { if (BIsSigned) { GemmQuantDispatch = GetMlasPlatform().GemmU8S8Dispatch; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp new file mode 100644 index 0000000000000..7d5817335bd77 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp @@ -0,0 +1,531 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_lsx.cpp + +Abstract: + + This module implements QGEMM kernels for LSX. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" +#include + +struct MLAS_GEMM_U8X8_KERNEL_LSX +{ + typedef int16_t PackedAType; + typedef int16_t PackedBType; + typedef uint8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 2; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_LSX::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_LSX::Strides; + +template<> +MLAS_FORCEINLINE constexpr +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (!BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_LSX::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + const __m128i ZeroVector = __lsx_vrepli_d(0); + uint16_t val = 1; + const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); + uint8_t PaddedMatrixAData[8] = { 0 }; + + // + // Process a single row of matrix A in a loop. + // + + while (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + __m128i ReductionVector = ZeroVector; + + // + // Zero extend the source bytes to 16-bits and write to the packed + // buffer. + // + // The packed buffer has the same data ordering as the source bytes, + // but CountK is aligned up to a multiple of 2 to maintain 32-bit + // alignment. All extra bytes are zero-padded. 
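+    // For example, with CountK = 3 and source bytes { 5, 6, 7 }, the packed
+    // buffer receives the 16-bit values { 5, 6, 7, 0 }, and the row sum
+    // accumulates 5 + 6 + 7 = 18.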
+ // + // These 16-bit values are also accumulated into an intermediate per-row + // accumulator. CountK cannot be greater than 128 to avoid overflowing + // these signed 16-bit accumulators. + // + + while (k >= 8) { + + __m128i Bytes = __lsx_vld((const __m128i*) & a[0], 0); + __lsx_vinsgr2vr_d(Bytes, 0, 1); + __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); + + ReductionVector = __lsx_vadd_h(ReductionVector, Words); + + __lsx_vst(Words, (__m128i*) & D[0], 0); + + a += 8; + D += 8; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* padded = PaddedMatrixAData; + uint8_t* padded_end = padded + k; + + do { + padded[0] = a[0]; + padded++; + a++; + } while (padded < padded_end); + + __m128i Bytes = __lsx_vld((__m128i*)PaddedMatrixAData, 0); + __lsx_vinsgr2vr_d(Bytes, 0, 1); + __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); + + ReductionVector = __lsx_vadd_h(ReductionVector, Words); + + // + // Copy pairs of 16-bit values from the vector to the packed + // buffer and rotate the vector for the next iteration. + // + + for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { + __lsx_vstelm_w(Words, (int32_t*)D, 0 , 0); + D += 2; + Words = __lsx_vshuf4i_w(Words, 0x39); //(0, 3, 2, 1) + } + } + + // + // Reduce the partial accumulators. + // + __m128i tmp1 = ZeroVector, tmp2 = ZeroVector; + tmp1 = __lsx_vmaddwev_w_h(tmp1, ReductionVector, OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ReductionVector, OnesWordBroadcast); + ReductionVector = __lsx_vadd_w(tmp1, tmp2); + ReductionVector = __lsx_vadd_w(ReductionVector, + __lsx_vshuf4i_w(ReductionVector, 0xee)); + ReductionVector = __lsx_vadd_w(ReductionVector, + __lsx_vshuf4i_w(ReductionVector, 0x11)); + + __lsx_vstelm_w(ReductionVector, RowSumBuffer++, 0 , 0); + + A += lda; + CountM -= 1; + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessLSX( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, + __m128i BytesRow0, + __m128i BytesRow1, + __m128i BitFlipVector, + __m128i ColumnSums[2] +) +{ + __m128i BytesInterleaved = __lsx_vilvl_b(BytesRow1, BytesRow0); + + BytesInterleaved = __lsx_vxor_v(BytesInterleaved, BitFlipVector); + + __m128i WordsInterleaved0 = __lsx_vsrai_h(__lsx_vilvl_b(BytesInterleaved, BytesInterleaved), 8); + __m128i WordsInterleaved1 = __lsx_vsrai_h(__lsx_vilvh_b(BytesInterleaved, BytesInterleaved), 8); + + ColumnSums[0] = __lsx_vadd_h(ColumnSums[0], WordsInterleaved0); + ColumnSums[1] = __lsx_vadd_h(ColumnSums[1], WordsInterleaved1); + + __lsx_vst(WordsInterleaved0, (__m128i*) & D[0], 0); + __lsx_vst(WordsInterleaved1, (__m128i*) & D[8], 0); +} + +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + uint16_t val = 1; + const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); + const __m128i BitFlipVector = __lsx_vreplgr2vr_w(BIsSigned ? 0 : 0x80808080); + + // + // Process 8 columns of matrix B in a loop. + // + + while (CountN >= 8) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSums[2]; + + ColumnSums[0] = __lsx_vldi(0); + ColumnSums[1] = __lsx_vldi(0); + + // + // Interleave rows of matrix B and write to the packed buffer. + // + // These values are also zero-extended and accumulated into an + // intermediate per-column accumulator. CountK cannot be greater than + // 128 to avoid overflowing these signed 16-bit accumulators. 
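+    // (After the optional bit flip, every packed value lies in [-128, 127],
+    // so 128 accumulation steps stay within [-16384, 16256], inside the
+    // int16_t range.)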
+ // + + while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { + + __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + __m128i BytesRow1 = __lsx_vld((const __m128i*) & b[ldb], 0); + __lsx_vinsgr2vr_d(BytesRow1, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); + + D += 16; + } + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); + ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); + ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); + + __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); + __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSums[2]; + uint8_t PaddedMatrixBData[16]; + + __lsx_vst(BitFlipVector, (__m128i*)PaddedMatrixBData, 0); + + ColumnSums[0] = __lsx_vldi(0); + ColumnSums[1] = __lsx_vldi(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + __m128i BytesRow1 = __lsx_vld((__m128i*) & PaddedMatrixBData[8], 0); + __lsx_vinsgr2vr_d(BytesRow1, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); + } + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); + ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); + ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); + + __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); + __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8MultiplyAccumulateRowLSX( + __m128i ABroadcast, + const int16_t* B, + __m128i Accumulators[2] +) +{ + __m128i BElements0 = __lsx_vld((__m128i*) & B[0], 0); + __m128i 
BElements1 = __lsx_vld((__m128i*) & B[8], 0); + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements0, ABroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements0, ABroadcast); + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vadd_w(tmp1, tmp2)); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements1, ABroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements1, ABroadcast); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vadd_w(tmp1, tmp2)); +} + +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + + while (CountN > 0) { + + __m128i Accumulators[2]; + + // + // Initialize the accumulators with the row and column sums. + // + + int32_t RowSumValue = RowSumBuffer[0]; + + if (ZeroPointB != nullptr) { + + int32_t ScaledRowSumBuffer[8]; + + for (size_t i = 0; i < 8; i++) { + ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; + } + + ZeroPointB += 8; + + Accumulators[0] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[0], 0); + Accumulators[1] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[4], 0); + + } + else { + + Accumulators[0] = __lsx_vreplgr2vr_w(RowSumValue); + Accumulators[1] = Accumulators[0]; + } + + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((const __m128i*) & ColumnSumBuffer[0], 0)); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((const __m128i*) & ColumnSumBuffer[4], 0)); + ColumnSumBuffer += 8; + + // + // Broadcast each pair of 16-bit values from the matrix A and multiply + // with the pair of 16-bit values from matrix B, and add the 32-bit + // intermediate into the accumulator registers. + // + + const int16_t* a = A; + size_t k = PackedCountK; + + while (k >= 4) { + + __m128i AElements = __lsx_vld((__m128i*)a, 0); + __m128i ABroadcast; + + ABroadcast = __lsx_vreplvei_w(AElements, 0); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 1); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[16], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 2); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[32], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 3); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[48], Accumulators); + + a += 4 * 2; + B += 4 * 16; + k -= 4; + } + + while (k > 0) { + + __m128i ABroadcast = __lsx_vldrepl_w((int32_t*)a, 0); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); + + a += 2; + B += 16; + k -= 1; + } + + // + // Output the accumulator block after optionally accumulating the values + // from matrix C. + // + + if (CountN >= 8) { + + if (!ZeroMode) { + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*) & C[0], 0)); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((__m128i*) & C[4], 0)); + } + + __lsx_vst(Accumulators[0], (__m128i*) & C[0], 0); + __lsx_vst(Accumulators[1], (__m128i*) & C[4], 0); + + C += 8; + CountN -= 8; + + } + else { + + // + // Output the remaining partial output block. 
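+            // The tail is handled as a 4/2/1 cascade: four lanes are stored,
+            // the upper lanes are shifted down, then two lanes, then one.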
+            //
+
+            if ((CountN & 4) != 0) {
+
+                if (!ZeroMode) {
+                    Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*)&C[0], 0));
+                }
+
+                __lsx_vst(Accumulators[0], (__m128i*)&C[0], 0);
+                C += 4;
+
+                Accumulators[0] = Accumulators[1];
+            }
+
+            if ((CountN & 2) != 0) {
+
+                if (!ZeroMode) {
+                    Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vinsgr2vr_d(__lsx_vld((__m128i*)&C[0], 0), 0, 1));
+                }
+
+                *((uint64_t*)&C[0]) = __lsx_vpickve2gr_d(Accumulators[0], 0);
+                C += 2;
+
+                Accumulators[0] = __lsx_vshuf4i_w(Accumulators[0], 0xee);
+            }
+
+            if ((CountN & 1) != 0) {
+
+                int32_t AccumulatorValue = __lsx_vpickve2gr_w(Accumulators[0], 0);
+
+                if (!ZeroMode) {
+                    AccumulatorValue += C[0];
+                }
+
+                C[0] = AccumulatorValue;
+            }
+
+            CountN = 0;
+        }
+    }
+
+    return 1;
+}
+
+const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX = {
+    MlasGemmQuantOperation<MLAS_GEMM_U8X8_KERNEL_LSX>,
+    nullptr,
+    nullptr,
+    MLAS_GEMM_U8X8_KERNEL_LSX::PackedK,
+    0,
+    1 // assembly kernel M stride
+};
diff --git a/onnxruntime/core/mlas/lib/qladd.cpp b/onnxruntime/core/mlas/lib/qladd.cpp
index 971ea0161d7af..5dafa17c2ae66 100644
--- a/onnxruntime/core/mlas/lib/qladd.cpp
+++ b/onnxruntime/core/mlas/lib/qladd.cpp
@@ -552,6 +552,119 @@ MlasQLinearAddKernelHelper(
             InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N);
     }
 }
+#elif defined(MLAS_LSX_INTRINSICS)
+
+template <typename DataType, bool IsScalarB>
+static
+void
+MlasQLinearAddKernelHelper(
+    const DataType* InputA,
+    float ScaleA,
+    int32_t ZeroPointA,
+    const DataType* InputB,
+    float ScaleB,
+    int32_t ZeroPointB,
+    float ScaleC,
+    int32_t ZeroPointC,
+    DataType* OutputC,
+    size_t N
+    )
+{
+    const float ScaleRatio_AC = ScaleA / ScaleC;
+    const float ScaleRatio_BC = ScaleB / ScaleC;
+    const auto VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC);
+    const auto VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC);
+    auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB));
+
+    MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi;
+    if (IsScalarB) {
+        float tmp_f = (float)*InputB;
+        uint32_t* tmp_p = (uint32_t*)&tmp_f;
+        vb_lo = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w(*tmp_p));
+        VectorFixedPart = __lsx_vfmadd_s(vb_lo, VectorScaleRatio_BC, VectorFixedPart);
+    }
+
+    __m128i tmp, tmp1;
+
+    while (N >= 8) {
+        const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputA, 0), 0, 1);
+        const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half);
+        InputA += 8;
+        va_lo = __lsx_vffint_s_w(MlasShiftRightInt32<DataType>(__lsx_vilvl_h(va_i16x8, va_i16x8), 24));
+        va_hi = __lsx_vffint_s_w(MlasShiftRightInt32<DataType>(__lsx_vilvh_h(va_i16x8, va_i16x8), 24));
+
+        if (!IsScalarB) {
+            const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputB, 0), 0, 1);
+            const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half);
+            InputB += 8;
+            vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32<DataType>(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24));
+            vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32<DataType>(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24));
+        }
+
+        MLAS_INT32X4 r_lo, r_hi;
+        if (IsScalarB) {
+            r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart));
+            r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart));
+        } else {
+            r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC)));
+            r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart),
__lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); + } + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); + + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + + N -= 8; + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((MLAS_INT32X4*)OutputC, 0), __lsx_vpickve2gr_d(vc, 0), 0), (MLAS_INT32X4*)OutputC, 0); + OutputC += 8; + } + + if (N > 0) { + uint8_t TailData[8] = { 0 }; + + MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N); + const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); + const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half); + va_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(va_i16x8, va_i16x8), 24)); + va_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(va_i16x8, va_i16x8), 24)); + + if (!IsScalarB) { + MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N); + const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); + const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half); + vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24)); + vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24)); + } + + MLAS_INT32X4 r_lo, r_hi; + if (IsScalarB) { + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart)); + } else { + r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC))); + r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); + } + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); + + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + + if (N & 4) { + __lsx_vstelm_w(vc, (int*)OutputC, 0, 0); + N -= 4; + OutputC += 4; + vc = __lsx_vshuf4i_w(vc, 0x39); //_MM_SHUFFLE(0, 3, 2, 1) + } + + uint32_t PackedValueC = (uint32_t)__lsx_vpickve2gr_w(vc, 0); + for (size_t i = 0; i < N; ++i) { + *((uint8_t*)OutputC + i) = (uint8_t)PackedValueC; + PackedValueC >>= 8; + } + } +} #else template diff --git a/onnxruntime/core/mlas/lib/qladd.h b/onnxruntime/core/mlas/lib/qladd.h index 8c05a6185324a..94568941a5660 100644 --- a/onnxruntime/core/mlas/lib/qladd.h +++ b/onnxruntime/core/mlas/lib/qladd.h @@ -463,5 +463,132 @@ MlasPackS16_128( { return reinterpret_cast(vec_packs(a, b)); } +#elif defined(MLAS_LSX_INTRINSICS) +#define LSX_DBG 1 +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ); + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); + return __lsx_vsra_w(v, imm_v); +#else + return __lsx_vsrai_w(v, imm); +#endif +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); + return __lsx_vsrl_w(v, imm_v); +#else + return __lsx_vsrli_w(v, imm); +#endif +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ); + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); + return __lsx_vsra_h(v, imm_v); +#else + 
return __lsx_vsrai_h(v, imm); +#endif +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); + return __lsx_vsrl_h(v, imm_v); +#else + return __lsx_vsrli_h(v, imm); +#endif +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ); + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ) +{ + // return _mm_packus_epi16(a, b); + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, a); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, b); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); + +} + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ) +{ + // return _mm_packs_epi16(a, b); + __m128i tmp, tmp1; + + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); + +} #endif diff --git a/onnxruntime/core/mlas/lib/qlgavgpool.cpp b/onnxruntime/core/mlas/lib/qlgavgpool.cpp index 1c2be0a833a3e..e44d7ad25c446 100644 --- a/onnxruntime/core/mlas/lib/qlgavgpool.cpp +++ b/onnxruntime/core/mlas/lib/qlgavgpool.cpp @@ -689,6 +689,316 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch( Output_zero_point, 0, 0, 1, Channels); } +#elif defined(MLAS_LSX_INTRINSICS) + +template +void MLASCALL +MlasQLinearGlobalAveragePoolNchw( + const T8Bits* Input, + float ScaleInput, + int32_t ZeroPointInput, + T8Bits* Output, + float ScaleOutput, + int32_t ZeroPointOutput, + size_t Channels, + size_t ImageSize, + int32_t* AccumulateBuffer + ) +{ + float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); + const int32_t bias[] = {-ZeroPointInput * static_cast(ImageSize), 0, 0, 0}; + const auto vbias = __lsx_vld((const __m128i*)&bias, 0); + const auto vzero = __lsx_vldi(0); + uint8_t buffer[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + int32_t* sum_buffer = AccumulateBuffer; + for (size_t c = Channels; c > 0; c--) { + + __m128i vacc_lo = vbias; + __m128i vacc_hi = vzero; + auto Len = ImageSize; + for (; Len >= 32; Len -= 32) { + + const __m128i vi0 = __lsx_vld((const __m128i*)Input, 0); + __lsx_vinsgr2vr_d(vi0, 0, 1); + const __m128i vi1 = __lsx_vld((const __m128i*)(Input + 8), 0); + __lsx_vinsgr2vr_d(vi1, 0, 1); + const __m128i vi2 = __lsx_vld((const __m128i*)(Input + 16), 0); + __lsx_vinsgr2vr_d(vi2, 0, 1); + const __m128i vi3 = __lsx_vld((const __m128i*)(Input + 24), 0); + __lsx_vinsgr2vr_d(vi3, 0, 1); + + if constexpr (std::is_signed::value) { + + const __m128i vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); + const __m128i vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); + const __m128i vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); + const __m128i vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); + const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), + __lsx_vadd_h(vxi2, vxi3)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vxi0 = __lsx_vilvl_b(vzero, vi0); + const __m128i vxi1 = __lsx_vilvl_b(vzero, vi1); + const __m128i vxi2 = __lsx_vilvl_b(vzero, vi2); + const __m128i vxi3 = __lsx_vilvl_b(vzero, vi3); + const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), + __lsx_vadd_h(vxi2, vxi3)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = 
__lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += 32; + } + for (; Len >= 8; Len -= 8) { + + if constexpr (std::is_signed::value) { + + const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1), vzero), 8); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += 8; + } + if (Len > 0) { + + memcpy(buffer, Input, Len); + + if constexpr (std::is_signed::value) { + + const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1), vzero), 8); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += Len; + } + + __m128i vacc = __lsx_vadd_w(vacc_lo, vacc_hi); // [ D C | B A ] + __m128i vshuf = __lsx_vshuf4i_w(vacc, 0xb1); // [ C D | A B ] _MM_SHUFFLE(2, 3, 0, 1) + __m128i vsums = __lsx_vadd_w(vacc, vshuf); // [ D+C C+D | B+A A+B ] + vshuf = __lsx_vshuf4i_w(vsums, 0x4e); // [ B+A A+B | D+C C+D ] _MM_SHUFFLE(1, 0, 3, 2) + vsums = __lsx_vadd_w(vsums, vshuf); + __lsx_vstelm_w(vsums, sum_buffer++, 0 , 0); + } + + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, + static_cast(ZeroPointOutput), 0, 0, 1, Channels); +} + +template +MLAS_FORCEINLINE +void +MlasQLinearGlobalAveragePoolNhwcSingleBatch( + const T8Bits* Input, + T8Bits* Output, + const T8Bits* LastOf8, + size_t ImageSize, + size_t Channels, + size_t Stride, + int32_t Bias, + float Scale, + T8Bits Output_zero_point, + int32_t* AccumulateBuffer, + const T8Bits* ZeroBuffer + ) +{ + + constexpr size_t PixelsPerIteration = 7; +#define LOAD_FULL_CHANNELS() \ + const __m128i vi0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i0, 0), 0 , 1); \ + i0 += 8; \ + const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i1, 0), 0 , 1); \ + i1 += 8; \ + const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i2, 0), 0 , 1); \ + i2 += 8; \ + const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i3, 0), 0 , 1); \ + i3 += 8; \ + const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i4, 0), 0 , 1); \ + i4 += 8; \ + const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i5, 0), 0 , 1); \ + i5 += 8; \ + const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i6, 0), 0 , 1); \ + i6 += 8 + +#define CALCULATE_ACCUMULATE_VECTORS() \ + __m128i vacc_lo = finish_one_pass ? __lsx_vld((__m128i*)acc, 0) : vbias; \ + __m128i vacc_hi = finish_one_pass ? 
__lsx_vld(((__m128i*)acc) + 1, 0) : vbias; \ + __m128i vxi0; \ + __m128i vxi1; \ + __m128i vxi2; \ + __m128i vxi3; \ + __m128i vxi4; \ + __m128i vxi5; \ + __m128i vxi6; \ + if constexpr (std::is_signed::value) { \ + vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); \ + vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); \ + vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); \ + vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); \ + vxi4 = __lsx_vsrai_h(__lsx_vilvl_b(vi4, vzero), 8); \ + vxi5 = __lsx_vsrai_h(__lsx_vilvl_b(vi5, vzero), 8); \ + vxi6 = __lsx_vsrai_h(__lsx_vilvl_b(vi6, vzero), 8); \ + } else { \ + vxi0 = __lsx_vilvl_b(vzero, vi0); \ + vxi1 = __lsx_vilvl_b(vzero, vi1); \ + vxi2 = __lsx_vilvl_b(vzero, vi2); \ + vxi3 = __lsx_vilvl_b(vzero, vi3); \ + vxi4 = __lsx_vilvl_b(vzero, vi4); \ + vxi5 = __lsx_vilvl_b(vzero, vi5); \ + vxi6 = __lsx_vilvl_b(vzero, vi6); \ + } \ + const __m128i vsum01 = __lsx_vadd_h(vxi0, vxi1); \ + const __m128i vsum23 = __lsx_vadd_h(vxi2, vxi3); \ + const __m128i vsum45 = __lsx_vadd_h(vxi4, vxi5); \ + const __m128i vsum016 = __lsx_vadd_h(vsum01, vxi6); \ + const __m128i vsum2345 = __lsx_vadd_h(vsum23, vsum45); \ + const __m128i vsum = __lsx_vadd_h(vsum016, vsum2345); \ + if constexpr (std::is_signed::value) { \ + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); \ + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); \ + } else { \ + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); \ + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); \ + } + + + T8Bits tail[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + bool finish_one_pass = false; + const __m128i vbias = __lsx_vreplgr2vr_w(Bias); + const __m128i vzero = __lsx_vldi(0); + size_t step_next_group = PixelsPerIteration * Stride - (Channels & ~size_t{7}); + + const T8Bits* i0 = Input; + const T8Bits* i1 = i0 + Stride; + const T8Bits* i2 = i1 + Stride; + const T8Bits* i3 = i2 + Stride; + const T8Bits* i4 = i0 + Stride * 4; + const T8Bits* i5 = i4 + Stride; + const T8Bits* i6 = i5 + Stride; + + for (; ImageSize > PixelsPerIteration; ImageSize -= PixelsPerIteration) { + + int32_t* acc = AccumulateBuffer; + size_t c = Channels; + for (; c >= 8; c -= 8) { + + LOAD_FULL_CHANNELS(); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + acc += 8; + } + if (c > 0) { + const __m128i vi0 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); + const __m128i vi1 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0 ,1); + const __m128i vi2 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0 ,1); + const __m128i vi3 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0 ,1); + const __m128i vi4 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i4 >= LastOf8 ? memcpy(tail, i4, c) : i4), 0), 0 ,1); + const __m128i vi5 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0 ,1); + const __m128i vi6 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6), 0), 0 ,1); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + } + finish_one_pass = true; + + i0 += step_next_group; + i1 += step_next_group; + i2 += step_next_group; + i3 += step_next_group; + i4 += step_next_group; + i5 += step_next_group; + i6 += step_next_group; + } + + if (ImageSize > 0) { + switch (ImageSize) { + case 1: + i1 = ZeroBuffer; + [[fallthrough]]; + case 2: + i2 = ZeroBuffer; + [[fallthrough]]; + case 3: + i3 = ZeroBuffer; + [[fallthrough]]; + case 4: + i4 = ZeroBuffer; + [[fallthrough]]; + case 5: + i5 = ZeroBuffer; + [[fallthrough]]; + case 6: + i6 = ZeroBuffer; + [[fallthrough]]; + default: + break; + } + + int32_t* acc = AccumulateBuffer; + size_t c = Channels; + for (; c >= 8; c -= 8) { + + LOAD_FULL_CHANNELS(); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + acc += 8; + } + + if (c > 0) { + const __m128i vi0 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); + const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(1 < ImageSize && i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0, 1); + const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(2 < ImageSize && i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0, 1); + const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(3 < ImageSize && i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0, 1); + const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(4 < ImageSize && i4 >= LastOf8 ? memcpy(tail, i4, c) : i4), 0), 0, 1); + const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(5 < ImageSize && i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0, 1); + const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(6 < ImageSize && i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6), 0), 0, 1); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + } + } + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, + Output_zero_point, 0, 0, 1, Channels); +} + #else // Pure C++ Implementation @@ -771,7 +1081,7 @@ MlasQLinearGlobalAveragePoolNhwc( #endif -#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) template void diff --git a/onnxruntime/core/mlas/lib/qlmul.cpp b/onnxruntime/core/mlas/lib/qlmul.cpp index 4b8537f2b378f..38818e1190d21 100644 --- a/onnxruntime/core/mlas/lib/qlmul.cpp +++ b/onnxruntime/core/mlas/lib/qlmul.cpp @@ -377,6 +377,170 @@ MlasQLinearMulKernel( MLAS_UNREFERENCED_PARAMETER(ValueBVector); } +#elif defined(MLAS_LSX_INTRINSICS) + +template +MLAS_FORCEINLINE +static +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ); + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + return __lsx_vilvl_b(ZeroVector, Int8Vector); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + return __lsx_vilvh_b(ZeroVector, Int8Vector); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + MLAS_UNREFERENCED_PARAMETER(ZeroVector); + return __lsx_vsrai_h(__lsx_vilvl_b(Int8Vector, Int8Vector), 8); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + MLAS_UNREFERENCED_PARAMETER(ZeroVector); + return __lsx_vsrai_h(__lsx_vilvh_b(Int8Vector, Int8Vector), 8); +} + +template +MLAS_FORCEINLINE +static +__m128i +MlasExtendToS16Debias( + __m128i Int8Vector, + __m128i ZeroVector, + __m128i VectorBias + ) +{ + return __lsx_vsub_h(MlasExtendToS16(Int8Vector, ZeroVector), VectorBias); +} + +MLAS_FORCEINLINE +static +__m128i +MlasQLinearMulVectorS16( + __m128i va_s16x8, + __m128i vb_s16x8, + __m128 VectorScaleRatio, + __m128 VectorZeroPointC + ) +{ + __m128i tmp, tmp1; + + const auto ab_lo = __lsx_vmul_h(va_s16x8, vb_s16x8); + const auto ab_hi = __lsx_vmuh_h(va_s16x8, vb_s16x8); + auto r_lo = __lsx_vilvl_h(ab_hi, ab_lo); + auto r_hi = __lsx_vilvh_h(ab_hi, ab_lo); + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_lo), VectorScaleRatio, VectorZeroPointC)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_hi), VectorScaleRatio, VectorZeroPointC)); + + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +template +static +void +MlasQLinearMulKernel( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + const auto VectorZeroPointA = __lsx_vreplgr2vr_h((int16_t)ZeroPointA); + const auto VectorZeroPointB = __lsx_vreplgr2vr_h((int16_t)ZeroPointB); + const auto VectorZeroPointC = MlasBroadcastFloat32x4((float)ZeroPointC); + const auto VectorScaleRatio = MlasBroadcastFloat32x4(ScaleA * ScaleB / ScaleC); + const auto ZeroVector = __lsx_vldi(0); + + uint8_t TailDataA[16] = { 0 }; + uint8_t TailDataB[16] = { 0 }; + __m128i vb_lo_s16x8, vb_hi_s16x8; + + if (IsScalarB) { + vb_lo_s16x8 = __lsx_vsub_h(__lsx_vreplgr2vr_h((int16_t)*InputB), VectorZeroPointB); + vb_hi_s16x8 = 
vb_lo_s16x8; + } + + while (N > 0) { + if (N < 16) { + MlasCopyTailBytes(TailDataA, (const uint8_t*)InputA, N); + InputA = (const DataType*)TailDataA; + if (!IsScalarB) { + MlasCopyTailBytes(TailDataB, (const uint8_t*)InputB, N); + InputB = (const DataType*)TailDataB; + } + } + + const auto va_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputA, 0); + InputA += 16; + const auto va_lo_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); + const auto va_hi_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); + + if (!IsScalarB) { + const auto vb_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputB, 0); + InputB += 16; + vb_lo_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); + vb_hi_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); + } + + const auto vc_lo_s16x8 = MlasQLinearMulVectorS16(va_lo_s16x8, vb_lo_s16x8, VectorScaleRatio, VectorZeroPointC); + const auto vc_hi_s16x8 = MlasQLinearMulVectorS16(va_hi_s16x8, vb_hi_s16x8, VectorScaleRatio, VectorZeroPointC); + auto vc = MlasPackS16_128(vc_lo_s16x8, vc_hi_s16x8); + + if (N >= 16) { + __lsx_vst(vc, (__m128i*)OutputC, 0); + OutputC += 16; + N -= 16; + } else { + __lsx_vst(vc, (__m128i*)TailDataA, 0); + MlasCopyTailBytes((uint8_t*)OutputC, TailDataA, N); + N = 0; + } + } +} + + #else // Pure C++ implementation. diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index 133ad79594c55..ffecc2dbeff9e 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -20,7 +20,9 @@ Module Name: #include "mlasi.h" -#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) +#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || \ + defined(MLAS_LSX_INTRINSICS) + #include // @@ -49,6 +51,9 @@ MlasQuantizeLinearVector( // is a NaN. FloatVector = vmaxnmq_f32(FloatVector, MinimumValueVector); FloatVector = vminnmq_f32(FloatVector, MaximumValueVector); +#elif defined(MLAS_LSX_INTRINSICS) + FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); + FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); #else // N.B. MINPS and MAXPS returns the value from the second vector if the // value from the first vector is a NaN. @@ -64,6 +69,9 @@ MlasQuantizeLinearVector( #if defined(MLAS_NEON64_INTRINSICS) auto IntegerVector = vcvtnq_s32_f32(FloatVector); IntegerVector = vaddq_s32(IntegerVector, ZeroPointVector); +#elif defined(MLAS_LSX_INTRINSICS) + auto IntegerVector = __lsx_vftint_w_s(FloatVector); + IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); #else // N.B. Assumes MXCSR has been configured with the default rounding mode of // "round to nearest even". 
@@ -213,6 +221,121 @@ MlasQuantizeLinearStoreSingleValue(
     vst1q_lane_s16(Output, vreinterpretq_s16_s32(IntegerVector), 0);
 }

+#elif defined(MLAS_LSX_INTRINSICS)
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<uint8_t>(
+    MLAS_INT32X4 integervector
+    )
+{
+
+    __m128i zero = __lsx_vldi(0);
+    __m128i tmp, tmp2;
+
+    tmp = __lsx_vmax_h(integervector, zero);
+    tmp2 = __lsx_vsat_hu(tmp, 7);
+
+    integervector = __lsx_vpickev_b(tmp2, tmp2);
+
+
+    tmp = __lsx_vmax_h(integervector, zero);
+    tmp2 = __lsx_vsat_hu(tmp, 7);
+
+    integervector = __lsx_vpickev_b(tmp2, tmp2);
+    return integervector;
+}
+
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<int8_t>(
+    MLAS_INT32X4 integervector
+    )
+{
+
+    __m128i tmp, tmp1;
+
+    tmp = __lsx_vsat_h(integervector, 7);
+    tmp1 = __lsx_vsat_h(integervector, 7);
+    integervector = __lsx_vpickev_b(tmp1, tmp);
+
+    tmp = __lsx_vsat_h(integervector, 7);
+    tmp1 = __lsx_vsat_h(integervector, 7);
+    integervector = __lsx_vpickev_b(tmp1, tmp);
+    return integervector;
+}
+
+template <typename OutputType>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStore4PackedValues(
+    MLAS_INT32X4 IntegerVector,
+    OutputType* Output
+    )
+{
+    // Copies the lower 4 packed elements of the vector into memory (Output).
+
+    if constexpr (std::is_same_v<OutputType, uint8_t> || std::is_same_v<OutputType, int8_t>) {
+        __lsx_vstelm_w(IntegerVector, reinterpret_cast<int32_t*>(Output), 0, 0);
+    } else {
+        static_assert(std::is_same_v<OutputType, uint16_t> || std::is_same_v<OutputType, int16_t>);
+
+        __lsx_vstelm_d(IntegerVector, reinterpret_cast<int64_t*>(Output), 0, 0);
+    }
+}
+
+
+template <typename OutputType>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStoreSingleValue(
+    MLAS_INT32X4 IntegerVector,
+    OutputType* Output
+    )
+{
+    static_assert(std::is_same_v<OutputType, uint8_t> ||
+                  std::is_same_v<OutputType, int8_t> ||
+                  std::is_same_v<OutputType, uint16_t> ||
+                  std::is_same_v<OutputType, int16_t>);
+
+    // Copies the lower element of the vector into memory (Output).
+    // Expects that the 32-bit element in lane 0 is already within the valid numerical
+    // range of the OutputType.
+    *Output = static_cast<OutputType>(__lsx_vpickve2gr_w(IntegerVector, 0));
+}
+
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<uint16_t>(
+    MLAS_INT32X4 IntegerVector
+    )
+{
+    __m128i zero = __lsx_vldi(0);
+    __m128i tmp, tmp2;
+
+    tmp = __lsx_vmax_w(IntegerVector, zero);
+    tmp2 = __lsx_vsat_wu(tmp, 15);
+
+    IntegerVector = __lsx_vpickev_h(tmp2, tmp2);
+    return IntegerVector;
+}
+
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<int16_t>(
+    MLAS_INT32X4 IntegerVector
+    )
+{
+    __m128i tmp, tmp1;
+
+    tmp = __lsx_vsat_w(IntegerVector, 15);
+    tmp1 = __lsx_vsat_w(IntegerVector, 15);
+    IntegerVector = __lsx_vpickev_h(tmp1, tmp);
+    return IntegerVector;
+}
 #else

 template<>
@@ -384,6 +507,8 @@ Return Value:

 #if defined(MLAS_NEON64_INTRINSICS)
         auto FloatVector = vld1q_dup_f32(Input + n);
+#elif defined(MLAS_LSX_INTRINSICS)
+        MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input + n, 0);
 #else
         auto FloatVector = _mm_load_ss(Input + n);
 #endif
@@ -1362,6 +1487,286 @@ MlasRequantizeOutput(
     }
 }

+#elif defined(MLAS_LSX_INTRINSICS)
+
+template <typename OutputType>
+void
+MlasRequantizeOutput(
+    const int32_t* Input,
+    size_t InputLeadingDimension,
+    OutputType* Output,
+    size_t OutputLeadingDimension,
+    const int32_t* Bias,
+    const float* Scale,
+    bool PerColumnScale,
+    OutputType ZeroPoint,
+    size_t StartM,
+    size_t StartN,
+    size_t CountM,
+    size_t CountN
+    )
+{
+    // TODO: double-check these clamp bounds.
+    float min_f = float(std::numeric_limits<OutputType>::lowest() - ZeroPoint);
+    float max_f = float(std::numeric_limits<OutputType>::max() - ZeroPoint);
+    const __m128 PerMatrixScaleVector = PerColumnScale ? MlasReinterpretAsFloat32x4(__lsx_vldi(0)) : MlasReinterpretAsFloat32x4(__lsx_vldrepl_w(Scale, 0));
+    const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w(*((uint32_t*)&min_f)));
+    const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w(*((uint32_t*)&max_f)));
+    const __m128i ZeroPointVector = __lsx_vreplgr2vr_w(ZeroPoint);
+
+    if (nullptr != Bias) {
+        Bias += StartN;
+    }
+    if (PerColumnScale) {
+        Scale += StartN;
+    }
+
+    Input += StartM * InputLeadingDimension + StartN;
+    Output += StartM * OutputLeadingDimension + StartN;
+
+    //
+    // Step through each row of the output matrix.
+    //
+
+    while (CountM-- > 0) {
+
+        const int32_t* bias = Bias;
+        const float* scale = PerColumnScale ? Scale : nullptr;
+        size_t n = CountN;
+
+        auto* RowInput = Input;
+        auto* RowOutput = Output;
+
+        //
+        // Process 16 columns of the matrices at a time.
+        //
+
+        while (n >= 16) {
+
+            //
+            // Load the input data and optionally add the per-column bias.
+            //
+
+            __m128i IntegerVector0 = __lsx_vld((const __m128i*)&RowInput[0], 0);
+            __m128i IntegerVector1 = __lsx_vld((const __m128i*)&RowInput[4], 0);
+            __m128i IntegerVector2 = __lsx_vld((const __m128i*)&RowInput[8], 0);
+            __m128i IntegerVector3 = __lsx_vld((const __m128i*)&RowInput[12], 0);
+            RowInput += 16;
+
+            if (bias != nullptr) {
+                IntegerVector0 = __lsx_vadd_w(IntegerVector0, __lsx_vld((const __m128i*)&bias[0], 0));
+                IntegerVector1 = __lsx_vadd_w(IntegerVector1, __lsx_vld((const __m128i*)&bias[4], 0));
+                IntegerVector2 = __lsx_vadd_w(IntegerVector2, __lsx_vld((const __m128i*)&bias[8], 0));
+                IntegerVector3 = __lsx_vadd_w(IntegerVector3, __lsx_vld((const __m128i*)&bias[12], 0));
+                bias += 16;
+            }
+
+            //
+            // Convert the integer values to float and apply the per-tensor or
+            // per-column scaling.
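+            // Each output is, roughly, saturate(round(float(acc) * scale) + zero_point),
+            // with the clamp applied in the float domain before the zero point is added.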
+ // + + __m128 FloatVector0 = __lsx_vffint_s_w(IntegerVector0); + __m128 FloatVector1 = __lsx_vffint_s_w(IntegerVector1); + __m128 FloatVector2 = __lsx_vffint_s_w(IntegerVector2); + __m128 FloatVector3 = __lsx_vffint_s_w(IntegerVector3); + + if (scale != nullptr) { + + FloatVector0 = __lsx_vfmul_s(FloatVector0, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[0], 0))); + FloatVector1 = __lsx_vfmul_s(FloatVector1, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[4], 0))); + FloatVector2 = __lsx_vfmul_s(FloatVector2, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[8], 0))); + FloatVector3 = __lsx_vfmul_s(FloatVector3, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[12], 0))); + scale += 16; + + } else { + + FloatVector0 = __lsx_vfmul_s(FloatVector0, PerMatrixScaleVector); + FloatVector1 = __lsx_vfmul_s(FloatVector1, PerMatrixScaleVector); + FloatVector2 = __lsx_vfmul_s(FloatVector2, PerMatrixScaleVector); + FloatVector3 = __lsx_vfmul_s(FloatVector3, PerMatrixScaleVector); + } + FloatVector0 = __lsx_vfmax_s(FloatVector0, MinimumValueVector); + FloatVector1 = __lsx_vfmax_s(FloatVector1, MinimumValueVector); + FloatVector2 = __lsx_vfmax_s(FloatVector2, MinimumValueVector); + FloatVector3 = __lsx_vfmax_s(FloatVector3, MinimumValueVector); + + FloatVector0 = __lsx_vfmin_s(FloatVector0, MaximumValueVector); + FloatVector1 = __lsx_vfmin_s(FloatVector1, MaximumValueVector); + FloatVector2 = __lsx_vfmin_s(FloatVector2, MaximumValueVector); + FloatVector3 = __lsx_vfmin_s(FloatVector3, MaximumValueVector); + + IntegerVector0 = __lsx_vftint_w_s(FloatVector0); + IntegerVector1 = __lsx_vftint_w_s(FloatVector1); + IntegerVector2 = __lsx_vftint_w_s(FloatVector2); + IntegerVector3 = __lsx_vftint_w_s(FloatVector3); + + IntegerVector0 = __lsx_vadd_w(IntegerVector0, ZeroPointVector); + IntegerVector1 = __lsx_vadd_w(IntegerVector1, ZeroPointVector); + IntegerVector2 = __lsx_vadd_w(IntegerVector2, ZeroPointVector); + IntegerVector3 = __lsx_vadd_w(IntegerVector3, ZeroPointVector); + + __m128i WordVector0; + __m128i WordVector1; + __m128i ByteVector; + + if (std::is_signed::value) { + + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(IntegerVector0, 15); + tmp1 = __lsx_vsat_w(IntegerVector1, 15); + WordVector0 = __lsx_vpickev_h(tmp1, tmp); + + tmp = __lsx_vsat_w(IntegerVector2, 15); + tmp1 = __lsx_vsat_w(IntegerVector3, 15); + WordVector1 = __lsx_vpickev_h(tmp1, tmp); + + tmp = __lsx_vsat_h(WordVector0, 7); + tmp1 = __lsx_vsat_h(WordVector1, 7); + ByteVector = __lsx_vpickev_b(tmp1, tmp); + + + } else { + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(IntegerVector0, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(IntegerVector1, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + WordVector0 = __lsx_vpickev_b(tmp3, tmp2); + + tmp = __lsx_vmax_h(IntegerVector2, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(IntegerVector3, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + WordVector1 = __lsx_vpickev_b(tmp3, tmp2); + + tmp = __lsx_vmax_h(WordVector0, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(WordVector1, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + ByteVector = __lsx_vpickev_b(tmp3, tmp2); + + } + + __lsx_vst(ByteVector, (__m128i*)RowOutput, 0); + RowOutput += 16; + + n -= 16; + } + + // + // Process the remaining columns of the matrices. + // + + while (n > 0) { + + // + // Load the input data and optionally add the per-column bias. 
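+            // Tail columns are consumed four at a time while possible; the
+            // final 1-3 values go through the scalar broadcast path below.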
+ // + + __m128i IntegerVector; + + if (n >= 4) { + + IntegerVector = __lsx_vld((const __m128i*)&RowInput[0], 0); + RowInput += 4; + + if (bias != nullptr) { + IntegerVector = __lsx_vadd_w(IntegerVector, __lsx_vld((const __m128i*)&bias[0], 0)); + bias += 4; + } + + } else { + + int32_t IntegerValue = *RowInput++; + + if (bias != nullptr) { + IntegerValue += *bias++; + } + IntegerVector = __lsx_vldrepl_w(&IntegerValue, 0); + } + + // + // Convert to integer values to float and apply the per-tensor or + // per-column scaling. + // + __m128 FloatVector = __lsx_vffint_s_w(IntegerVector); + __m128 ScaleVector; + + if (scale != nullptr) { + + if (n >= 4) { + ScaleVector = MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)scale, 0)); + scale += 4; + } else { + ScaleVector = (__m128)__lsx_vldrepl_w(scale, 0); + scale += 1; + } + + } else { + ScaleVector = PerMatrixScaleVector; + } + FloatVector = __lsx_vfmul_s(FloatVector, ScaleVector); + + FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); + FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); + + IntegerVector = __lsx_vftint_w_s(FloatVector); + IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); + + if (std::is_signed::value) { + + __m128i tmp; + tmp = __lsx_vsat_w(IntegerVector, 15); + IntegerVector = __lsx_vpickev_h(tmp, tmp); + + tmp = __lsx_vsat_h(IntegerVector, 7); + IntegerVector = __lsx_vpickev_b(tmp, tmp); + + } else { + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2; + + tmp = __lsx_vmax_h(IntegerVector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + IntegerVector = __lsx_vpickev_b(tmp2, tmp2); + + tmp = __lsx_vmax_h(IntegerVector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + IntegerVector = __lsx_vpickev_b(tmp2, tmp2); + + } + + uint32_t OutputValue = uint32_t(__lsx_vpickve2gr_w(IntegerVector, 0)); + + if (n >= 4) { + + *reinterpret_cast(RowOutput) = OutputValue; + RowOutput += 4; + + n -= 4; + + } else { + + *RowOutput = uint8_t(OutputValue); + RowOutput += 1; + + n -= 1; + } + } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; + } +} + #else template diff --git a/onnxruntime/core/mlas/lib/reorder.cpp b/onnxruntime/core/mlas/lib/reorder.cpp index 99c1dbac3b692..b329ea2ffb149 100644 --- a/onnxruntime/core/mlas/lib/reorder.cpp +++ b/onnxruntime/core/mlas/lib/reorder.cpp @@ -180,6 +180,31 @@ Return Value: v[2] = _mm_movelh_ps(t[2], t[3]); v[3] = _mm_movehl_ps(t[3], t[2]); + MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); + MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); + MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); + MlasStoreFloat32x4(&D[ScatterStride * 3], v[3]); +#elif defined(MLAS_LSX_INTRINSICS) + + MLAS_FLOAT32X4 v[4]; + MLAS_FLOAT32X4 t[4]; + + v[0] = MlasLoadFloat32x4(&S[GatherStride * 0]); + v[1] = MlasLoadFloat32x4(&S[GatherStride * 1]); + v[2] = MlasLoadFloat32x4(&S[GatherStride * 2]); + v[3] = MlasLoadFloat32x4(&S[GatherStride * 3]); + + t[0] = (__m128)__lsx_vilvl_w((__m128i)v[1], (__m128i)v[0]); + t[2] = (__m128)__lsx_vilvh_w((__m128i)v[1], (__m128i)v[0]); + t[1] = (__m128)__lsx_vilvl_w((__m128i)v[3], (__m128i)v[2]); + t[3] = (__m128)__lsx_vilvh_w((__m128i)v[3], (__m128i)v[2]); + + + v[0] = (__m128)__lsx_vpickev_d((__m128i) t[1],(__m128i) t[0]); + v[1] = (__m128)__lsx_vpickod_d((__m128i) t[1],(__m128i) t[0]); + v[2] = (__m128)__lsx_vpickev_d((__m128i) t[3],(__m128i) t[2]); + v[3] = (__m128)__lsx_vpickod_d((__m128i) t[3],(__m128i) t[2]); + MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); 
MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); @@ -456,7 +481,6 @@ Return Value: &TaskStart, &TasksRemaining); size_t TaskEnd = TaskStart + TasksRemaining; - // // Rebase the pointers to the source and destination buffers for this thread. // @@ -567,18 +591,17 @@ Return Value: WorkBlock.S = S; WorkBlock.D = D; - WorkBlock.OutputChannels = size_t(OutputShape[1]); WorkBlock.OutputSize = size_t(OutputShape[2]) * size_t(OutputShape[3]); const size_t BlockSize = MlasNchwcGetBlockSize(); const size_t TasksPerBatch = size_t(ceil(((float)WorkBlock.OutputChannels) / BlockSize)); const size_t BatchCount = size_t(OutputShape[0]); - const size_t TasksCount = BatchCount * TasksPerBatch; + const size_t TasksCount = BatchCount * TasksPerBatch; WorkBlock.TasksCount = TasksCount; // - // Schedule the operation across a set of worker threads if the output + // Schedule the operation across a set of worker threads if the output // tensor is sufficienly large. Limit the number of threads to at least // the number of available tasks. // @@ -590,7 +613,7 @@ Return Value: if (size_t(TargetThreadCount) > TasksCount) { TargetThreadCount = ptrdiff_t(TasksCount); } - } + } WorkBlock.TargetThreadCount = TargetThreadCount; MlasExecuteThreaded(MlasReorderOutputNchwThreaded, &WorkBlock, TargetThreadCount, ThreadPool); diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 1ce64712d63dc..4d7a1ceb4eee7 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -472,7 +472,7 @@ Return Value: const float* b = B; size_t x = CountX; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* SgemmTransposePackB16x4Routine = GetMlasPlatform().TransposePackB16x4Routine; @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index 74d65f934aaf5..f9cf1605787aa 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); const size_t 
OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index f964b1affec31..7f1d1b084aec0 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -15,6 +15,9 @@ Module Name: --*/ #include "sqnbitgemm.h" +#ifdef MLAS_JBLAS +#include "jblas_gemm.h" +#endif namespace { @@ -142,3 +145,127 @@ MlasIsSQNBitGemmAvailable( return true; } + +size_t MLASCALL +MlasNBitsGemmPackBSize( + size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType +) +{ +#ifdef MLAS_JBLAS + if (nbits == 4) { + auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } +#endif + (void)(N); + (void)(K); + (void)(BlkSize); + (void)(nbits); + (void)(isAsym); + (void)(CompType); + return 0; +} + +void MLASCALL +MlasNBitsGemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + int nbits, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +) +{ +#ifdef MLAS_JBLAS + if (nbits == 4) { + if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +#endif + (void)(PackedBuf); + (void)(QData); + (void)(Scale); + (void)(Zp); + (void)(N); + (void)(K); + (void)(ldb); + (void)(BlkSize); + (void)(nbits); + (void)(isAsym); + 
(void)(lastCall);
+    (void)(CompType);
+    (void)(ThreadPool);
+}
+
+void MLASCALL
+MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool)
+{
+#ifdef MLAS_JBLAS
+    if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) {
+        return;
+    }
+#endif
+    (void)(FpData);
+    (void)(PackedBuf);
+    (void)(N);
+    (void)(K);
+    (void)(ldb);
+    (void)(ThreadPool);
+}
+
+size_t MLASCALL
+MlasSQNBitsGemmBatchWorkspaceSize(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+)
+{
+#ifdef MLAS_JBLAS
+    return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams);
+#endif
+    (void)(M);
+    (void)(N);
+    (void)(K);
+    (void)(BatchN);
+    (void)(DataParams);
+    return 0;
+}
+
+void MLASCALL
+MlasSQNBitsGemmBatchPackedB(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+    void* WorkSpace,
+    MLAS_THREADPOOL* ThreadPool
+)
+{
+    GetMlasPlatform();
+#ifdef MLAS_JBLAS
+    if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast<int8_t*>(WorkSpace), ThreadPool)) {
+        // PackedWeight is created by jblas
+        return;
+    }
+#endif
+    (void)(M);
+    (void)(N);
+    (void)(K);
+    (void)(BatchN);
+    (void)(DataParams);
+    (void)(WorkSpace);
+    (void)(ThreadPool);
+}
diff --git a/onnxruntime/core/mlas/lib/transpose.cpp b/onnxruntime/core/mlas/lib/transpose.cpp
index 86b0897bb91ec..a758a0e59fb4f 100644
--- a/onnxruntime/core/mlas/lib/transpose.cpp
+++ b/onnxruntime/core/mlas/lib/transpose.cpp
@@ -371,6 +371,121 @@ MlasTranspose16x16Block(
     vec_vsx_st(e0, 0, &Output[OutputStride * 14]);
     vec_vsx_st(e1, 0, &Output[OutputStride * 15]);
 }
+
+#elif defined(MLAS_LSX_INTRINSICS)
+
+MLAS_FORCEINLINE
+void
+MlasTranspose4x4Block(
+    const uint32_t* Input,
+    size_t InputStride,
+    uint32_t* Output,
+    size_t OutputStride
+    )
+{
+    __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0);
+    __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0);
+    __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0);
+    __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0);
+
+    __m128i b0 = __lsx_vilvl_w(a2, a0);
+    __m128i b1 = __lsx_vilvh_w(a2, a0);
+    __m128i b2 = __lsx_vilvl_w(a3, a1);
+    __m128i b3 = __lsx_vilvh_w(a3, a1);
+    __m128i c0 = __lsx_vilvl_w(b2, b0);
+    __m128i c1 = __lsx_vilvh_w(b2, b0);
+    __m128i c2 = __lsx_vilvl_w(b3, b1);
+    __m128i c3 = __lsx_vilvh_w(b3, b1);
+
+    __lsx_vst(c0, (__m128i*)&Output[OutputStride * 0], 0);
+    __lsx_vst(c1, (__m128i*)&Output[OutputStride * 1], 0);
+    __lsx_vst(c2, (__m128i*)&Output[OutputStride * 2], 0);
+    __lsx_vst(c3, (__m128i*)&Output[OutputStride * 3], 0);
+}
+
+MLAS_FORCEINLINE
+void
+MlasTranspose4x4Block(
+    const uint16_t* Input,
+    size_t InputStride,
+    uint16_t* Output,
+    size_t OutputStride
+    )
+{
+    __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0);
+    __lsx_vinsgr2vr_d(a0, 0 , 1);
+    __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0);
+    __lsx_vinsgr2vr_d(a1, 0 , 1);
+    __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0);
+    __lsx_vinsgr2vr_d(a2, 0 , 1);
+    __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0);
+    __lsx_vinsgr2vr_d(a3, 0 , 1);
+
+    __m128i b0 = __lsx_vilvl_h(a2, a0);
+    __m128i b1 = __lsx_vilvl_h(a3, a1);
+    __m128i c0 = __lsx_vilvl_h(b1, b0);
+    __m128i c1 = __lsx_vilvh_h(b1, b0);
+
+    __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0),
__lsx_vpickve2gr_d(c0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(c0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(c1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(c1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); +} + +MLAS_FORCEINLINE +void +MlasTranspose8x8Block( + const uint8_t* Input, + size_t InputStride, + uint8_t* Output, + size_t OutputStride + ) +{ + __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); + __lsx_vinsgr2vr_d(a0, 0, 1); + __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); + __lsx_vinsgr2vr_d(a1, 0, 1); + __m128i b0 = __lsx_vilvl_b(a1, a0); + + __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); + __lsx_vinsgr2vr_d(a2, 0, 1); + __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); + __lsx_vinsgr2vr_d(a3, 0, 1); + __m128i b1 = __lsx_vilvl_b(a3, a2); + + __m128i a4 = __lsx_vld((const __m128i*)&Input[InputStride * 4], 0); + __lsx_vinsgr2vr_d(a4, 0, 1); + __m128i a5 = __lsx_vld((const __m128i*)&Input[InputStride * 5], 0); + __lsx_vinsgr2vr_d(a5, 0, 1); + __m128i b2 = __lsx_vilvl_b(a5, a4); + + __m128i a6 = __lsx_vld((const __m128i*)&Input[InputStride * 6], 0); + __lsx_vinsgr2vr_d(a6, 0, 1); + __m128i a7 = __lsx_vld((const __m128i*)&Input[InputStride * 7], 0); + __lsx_vinsgr2vr_d(a7, 0, 1); + __m128i b3 = __lsx_vilvl_b(a7, a6); + __m128i c0 = __lsx_vilvl_h(b1, b0); + __m128i c1 = __lsx_vilvh_h(b1, b0); + __m128i c2 = __lsx_vilvl_h(b3, b2); + __m128i c3 = __lsx_vilvh_h(b3, b2); + + __m128 d0 = (__m128)(__lsx_vilvl_w(c2, c0)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(d0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(d0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + + __m128 d1 = (__m128)(__lsx_vilvh_w(c2, c0)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(d1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(d1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); + + __m128 d2 = (__m128)(__lsx_vilvl_w(c3, c1)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 4], 0), __lsx_vpickve2gr_d(d2, 0), 0), (__m128i *)&Output[OutputStride * 4], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 5], 0), __lsx_vpickve2gr_d(d2, 1), 0), (__m128i *)&Output[OutputStride * 5], 0); + + __m128 d3 = (__m128)(__lsx_vilvh_w(c3, c1)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 6], 0), __lsx_vpickve2gr_d(d3, 0), 0), (__m128i *)&Output[OutputStride * 6], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 7], 0), __lsx_vpickve2gr_d(d3, 1), 0), (__m128i *)&Output[OutputStride * 7], 0); +} + #endif template @@ -472,7 +587,8 @@ Return Value: uint32_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ + 
defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -597,7 +713,7 @@ Return Value: uint16_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -734,7 +850,7 @@ Return Value: uint8_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) while (m >= 8) { diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format new file mode 100644 index 0000000000000..84b876706161d --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format @@ -0,0 +1,7 @@ +Language: Cpp +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 120 +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SortIncludes: false diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt new file mode 100644 index 0000000000000..5d9c5edf45a96 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.5) + +project(jblas LANGUAGES CXX VERSION 0.1.0) + +file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) +file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) + +add_library(${PROJECT_NAME} INTERFACE) +add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) + +target_include_directories( + ${PROJECT_NAME} INTERFACE + "$" + "$" +) + +if(WIN32) + target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) + target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100) + #4068 ignore unroll and GCC flags + #4849 ignore collapse + #6262 ignore stack too large + #4702 unreachable code(false warning on constexpr condition) + #4100 unreferenced formal parameter + + target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size +endif(WIN32) + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h new file mode 100644 index 0000000000000..143adb771760b --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h @@ -0,0 +1,303 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
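For context on the dispatch gates extended above: MlasTranspose walks the matrix in 4x4 (or 8x8) tiles on targets with vector support and finishes with scalar tails. A minimal sketch of that driver pattern for uint32_t, with a scalar tile kernel standing in for MlasTranspose4x4Block (illustrative only; names and structure simplified from MLAS):

```cpp
#include <cstddef>
#include <cstdint>

// Scalar stand-in for the 4x4 tile kernel (the LSX version above uses
// __lsx_vilvl_w/__lsx_vilvh_w interleaves instead).
static void Transpose4x4Scalar(const uint32_t* s, size_t sstride,
                               uint32_t* d, size_t dstride) {
    for (size_t i = 0; i < 4; i++) {
        for (size_t j = 0; j < 4; j++) {
            d[j * dstride + i] = s[i * sstride + j];
        }
    }
}

// Driver pattern: 4-row tiles first, then scalar tails for leftover
// columns and rows. S is M x N row-major; D receives the N x M transpose.
void TransposeU32(const uint32_t* S, uint32_t* D, size_t M, size_t N) {
    size_t m = M;
    while (m >= 4) {
        const uint32_t* s = S;
        uint32_t* d = D;
        size_t n = N;
        for (; n >= 4; n -= 4, s += 4, d += 4 * M) {
            Transpose4x4Scalar(s, N, d, M);
        }
        for (; n > 0; n--, s++, d += M) {  // column tail
            for (size_t i = 0; i < 4; i++) {
                d[i] = s[i * N];
            }
        }
        S += 4 * N;
        D += 4;
        m -= 4;
    }
    for (; m > 0; m--, S += N, D++) {      // row tail
        for (size_t n = 0; n < N; n++) {
            D[n * M] = S[n];
        }
    }
}
```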
+#pragma once +#include + +#include +#include +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +#define OFFSET(field) offsetof(params, field) + +namespace jblas { + +namespace xbyak { +class JitBase : protected Xbyak::CodeGenerator { + protected: + JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {} + + void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) { + xor_(reg, reg); + mov(reg.cvt32(), addr); + } + + void vreg_push(const Xbyak::Reg64& baseaddr) { +#ifdef _WIN32 + for (int i = 0; i < 10; i++) { + movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i)); + } +#endif + } + + void vreg_pop(const Xbyak::Reg64& baseaddr) { +#ifdef _WIN32 + for (int i = 0; i < 10; i++) { + movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]); + } +#endif + } + + void padto_le(const Xbyak::Reg64& _src, int padding) { + // _src=_src/padding*padding + if (padding == 1) { + return; + } + for (int i = 1; i < 16; i++) { + if ((1 << i) == padding) { + shr(_src, i); + shl(_src, i); + return; + } + } + assert(0); + } + + void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, + const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { + inLocalLabel(); + lea(_tmp, _total); + sub(_tmp, _pos); + cmp(_tmp, N); + jb(".maskflag"); + cmp(_tmp, 0); + jl(".zeroflag"); + uint64_t allmask = (static_cast(1) << N) - 1; + if (N == 64) { + allmask = static_cast(-1); + } + mov(_tmp, allmask); + kmovq(_msk, _tmp); + jmp(".maskend"); + L(".maskflag"); + mov(_tmp1, 1); + shlx(_tmp1, _tmp1, _tmp); + sub(_tmp1, 1); + kmovq(_msk, _tmp1); + jmp(".maskend"); + L(".zeroflag"); + mov(_tmp1, 0); + kmovq(_msk, _tmp1); + L(".maskend"); + outLocalLabel(); + } + void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total, + const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { + generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N); + } +}; + +class JitAvx : protected JitBase { + protected: + static int constexpr VBits = 256; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 16; + typedef Xbyak::Ymm vreg_t; +}; + +class JitAvx2 : protected JitAvx { + protected: + static int constexpr VBits = 256; + typedef Xbyak::Ymm vreg_t; + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } + + void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { + vpmovzxwd(dst, addr); + vpslld(dst, dst, 16); + } +}; + +class JitAvx512f : protected JitAvx2 { + protected: + static int constexpr VBits = 512; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 32; + typedef Xbyak::Zmm vreg_t; + + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } + + void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { + vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); + vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); + vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6)); + vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6)); + vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6)); + vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6)); + } + + void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) { + for (int i = 0; i < 8; ++i) { + vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]); + vpunpckhdq(tmp[2 * i 
+ 1], src[2 * i], src[2 * i + 1]); + } + + for (int i = 0; i < 4; ++i) { + vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]); + vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]); + vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]); + vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]); + } + + for (int i = 0; i < 2; ++i) { + vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88); + vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88); + vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88); + vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88); + vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd); + vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd); + vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd); + vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd); + } + + // last step and move out + for (int i = 0; i < N; ++i) { + vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd); + } + } + + void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) { + vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]); + vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]); + vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]); + vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]); + + vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]); + vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]); + vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]); + vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]); + vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4); + vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4); + vmovups(src_4regs[0], tmp_regs[1]); + vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); + vmovups(src_4regs[1], tmp_regs[3]); + vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); + vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14); + vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14); + vmovups(src_4regs[2], tmp_regs[1]); + vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); + vmovups(src_4regs[3], tmp_regs[3]); + vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); + } + + void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { + vpsrld(_fp32, _fp32, 16); + vpmovdw(_bf16, _fp32); + } + + void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) { + vpmovzxwd(dst, addr); + vpslld(dst, dst, 16); + } + + void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) { + mov(tmp.cvt16(), addr); + shl(tmp.cvt32(), 16); + vpbroadcastd(dst, tmp.cvt32()); + } + + void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) { + auto bf16 = Xbyak::Ymm(_fp32.getIdx()); + cvt_fp32_bf16(bf16, _fp32); + vmovups(_add, bf16); + } +}; + +class JitAvx512_bf16 : protected JitAvx512f {}; + +class JitAvx512_fp16 : protected JitAvx512f {}; + +class JitAvx512vnni : protected JitAvx512f { + protected: + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); + } +}; + +class JitAvxvnni : protected JitAvx2 { + protected: + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + vpdpbusds(x1, x2, op, 
Xbyak::VexEncoding);
+  }
+};
+
+class JitAmxtile : protected JitAvx512f {
+ public:
+  struct alignas(64) tileconfig_t {
+    uint8_t palette_id;
+    uint8_t reserved[15];
+    uint16_t colb[16];
+    uint8_t rows[16];
+  };
+  static int constexpr TileCount = 8;
+
+  typedef long long (*configure_t)(void*);
+
+  static void generate_config(Xbyak::CodeGenerator* g) {
+    Xbyak::util::StackFrame st(g, 1, 0, 0);
+    auto& parambase = st.p[0];
+    g->ldtilecfg(g->ptr[parambase]);
+  }
+
+  static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum,
+                              int CNum) {
+    // Filling tile configure structure. Could be done offline.
+    tc.palette_id = 1;
+    // Configure C tiles
+    int t = 0;
+    for (; t < CNum; ++t) {
+      tc.rows[t] = static_cast<uint8_t>(TILE_M);
+      tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
+    }
+    // Configure A tiles
+    for (; t < CNum + ANum; ++t) {
+      tc.rows[t] = static_cast<uint8_t>(TILE_M);
+      tc.colb[t] = static_cast<uint16_t>(TILE_K * elesize);
+    }
+    // Configure B tile. B effectively has 64 rows and 16 columns.
+    int kpack = 4 / elesize;
+    for (; t < CNum + ANum + BNum; ++t) {
+      tc.rows[t] = static_cast<uint8_t>(TILE_K / kpack);
+      tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
+    }
+  }
+};
+
+class JitAmxbf16 : protected JitAmxtile {
+ protected:
+  void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); }
+};
+
+class JitAmxint8 : protected JitAmxtile {
+ protected:
+  template <typename AT, typename BT>
+  void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3);
+};
+template <>
+inline void JitAmxint8::_tdpb<int8_t, int8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+  tdpbssd(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<int8_t, uint8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+  tdpbsud(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<uint8_t, int8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+  tdpbusd(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<uint8_t, uint8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+  tdpbuud(x1, x2, x3);
+}
+}  // namespace xbyak
+}  // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
new file mode 100644
index 0000000000000..8ecf3535c17f4
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
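configure_tiles() above sizes the B tile around the VNNI-style repacking: kpack = 4/elesize values share one 32-bit lane, so B is configured with TILE_K/kpack rows of TILE_N*4 bytes. A worked example, assuming an int8 AMX kernel with TILE_M=16, TILE_N=16, TILE_K=64 (values chosen for illustration):

```cpp
#include <cstdio>

int main() {
    const int TILE_M = 16, TILE_N = 16, TILE_K = 64;
    const int elesize = 1;          // int8 elements
    const int kpack = 4 / elesize;  // 4 int8 values packed per 32-bit lane

    // C accumulates int32: 16 rows x 16 * 4 = 64 bytes per row.
    std::printf("C: rows=%d colb=%d\n", TILE_M, TILE_N * 4);
    // A stays int8: 16 rows x 64 * 1 = 64 bytes per row.
    std::printf("A: rows=%d colb=%d\n", TILE_M, TILE_K * elesize);
    // B is repacked 4-deep: 64 / 4 = 16 rows x 16 * 4 = 64 bytes per row,
    // which is the "effectively 64 rows and 16 columns" noted in the code.
    std::printf("B: rows=%d colb=%d\n", TILE_K / kpack, TILE_N * 4);
    return 0;
}
```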
+#pragma once +#include +enum JBLAS_CODE { + JblasSuccess = 0, + JblasInvalidParam = 1, + JblasInvalidISA = 2, + JblasRuntimeError = 4, + JblasNotSupport = 8, +}; +enum JBLAS_ISA : uint32_t { + JblasNoSIMD = 0, + JblasAVX, + JblasAVX2, + JblasAVX_VNNI, + JblasAVX512F, + JblasAVX512_VNNI, + JblasAMX_BF16, + JblasAMX_INT8, + JblasAVX512_FP16, + JblasAVX512_BF16, +}; +enum class JBLAS_DTYPE : uint32_t { + EleBitsMask = 0xff, + EleBitsUndef = 0, + EleBits4 = 4, + EleBits8 = 8, + EleBits16 = 16, + EleBits32 = 32, + EleBits64 = 64, + TypeMask = 0xff00, + TypeFloat = 0 << 8, + TypeInt = 1 << 8, + SubTypeMask = 0xff0000, + SubType0 = 0 << 16, + SubType1 = 1 << 16, + SubType2 = 2 << 16, + F64 = EleBits64 | TypeFloat, + F32 = EleBits32 | TypeFloat, + F16 = EleBits16 | TypeFloat, + BF16 = EleBits16 | TypeFloat | SubType1, + F8_E4M3 = EleBits8 | TypeFloat, + F8_E5M2 = EleBits8 | TypeFloat | SubType1, + F8_E3M4 = EleBits8 | TypeFloat | SubType2, + S8 = EleBits8 | TypeInt, + U8 = EleBits8 | TypeInt | SubType1, + S4_CLIP = EleBits4 | TypeInt, + S4_FULLRANGE = EleBits4 | TypeInt | SubType1, + F4_E2M1 = EleBits4 | TypeFloat, + F4_BNB = EleBits4 | TypeFloat | SubType1, + F4_NF4 = EleBits4 | TypeFloat | SubType2, + S32 = EleBits32 | TypeInt, + U32 = EleBits32 | TypeInt | SubType1, +}; + +enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 }; +enum JBLAS_TRANSPOSE { + JblasNoTrans = 111, + JblasTrans = 112, + JblasConjTrans = 113, +}; +enum JBLAS_ELTWISEOP { + GELU, + SWISH, + TANH, + EXP, + LOW_PRECISION_EXP, + RELU, + LINEAR, +}; + +enum class JBLAS_PROLOGUEB_IDS : uint32_t { + Undef = (uint32_t)-1, + Begin = 0, + NormalBegin = Begin, + WeightPack = NormalBegin, + NormalEnd, + KBlockBegin = NormalEnd, + WeightKBlockS8 = KBlockBegin, + WeightKBlockS4, + WeightKBlockF4, + KBlockEnd, + End, +}; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h new file mode 100644 index 0000000000000..5cac1080bc610 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h @@ -0,0 +1,277 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
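The JBLAS_DTYPE values above are bit-packed: the low byte carries the element width, the second byte the float/int class, and the third byte a subtype discriminator (so BF16 = EleBits16 | TypeFloat | SubType1). A decoding sketch using hypothetical helpers that mirror the masks defined in the enum:

```cpp
#include <cstdint>
#include <cstdio>

// Mask values copied from the JBLAS_DTYPE enum above.
constexpr uint32_t kEleBitsMask = 0xff;
constexpr uint32_t kTypeMask = 0xff00;
constexpr uint32_t kTypeFloat = 0 << 8;
constexpr uint32_t kSubTypeMask = 0xff0000;

constexpr uint32_t bits_of(uint32_t t) { return t & kEleBitsMask; }
constexpr bool is_float(uint32_t t) { return (t & kTypeMask) == kTypeFloat; }
constexpr uint32_t subtype_of(uint32_t t) { return (t & kSubTypeMask) >> 16; }

int main() {
    const uint32_t bf16 = 16 | kTypeFloat | (1 << 16);  // BF16 encoding
    const uint32_t s4_clip = 4 | (1 << 8);              // S4_CLIP encoding
    std::printf("bf16:   %u bits, float=%d, subtype=%u\n",
                bits_of(bf16), is_float(bf16), subtype_of(bf16));
    std::printf("s4clip: %u bits, float=%d, subtype=%u\n",
                bits_of(s4_clip), is_float(s4_clip), subtype_of(s4_clip));
    return 0;
}
```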
+#pragma once +#include "jit_blas.h" +#include "xbyak/xbyak_util.h" + +namespace jblas { + +namespace device { + +struct X64_ISA { + int64_t MMX : 1; // 0 + int64_t SSE : 1; // 1 + int64_t SSE2 : 1; // 2 + int64_t SSE3 : 1; // 3 + int64_t SSSE3 : 1; // 4 + int64_t SSE41 : 1; // 5 + int64_t SSE42 : 1; // 6 + int64_t AVX : 1; // 7 + int64_t F16C : 1; // 8 + int64_t FMA : 1; // 9 + int64_t AVX2 : 1; // 10 + int64_t AVX_VNNI : 1; // 11 + int64_t AVX_VNNI_INT8 : 1; // 12 + int64_t AVX_NE_CONVERT : 1; // 13 + int64_t AVX_IFMA : 1; // 14 + int64_t AVX512F : 1; // 15 + int64_t AVX512BW : 1; // 16 + int64_t AVX512CD : 1; // 17 + int64_t AVX512DQ : 1; // 18 + int64_t AVX512ER : 1; // 19 + int64_t AVX512IFMA52 : 1; // 20 + int64_t AVX512PF : 1; // 21 + int64_t AVX512VL : 1; // 22 + int64_t AVX512VPOPCNTDQ : 1; // 23 + int64_t AVX512_4FMAPS : 1; // 24 + int64_t AVX512_4VNNIW : 1; // 25 + int64_t AVX512_BF16 : 1; // 26 + int64_t AVX512_BITALG : 1; // 27 + int64_t AVX512_VBMI : 1; // 28 + int64_t AVX512_VBMI2 : 1; // 29 + int64_t AVX512_VNNI : 1; // 30 + int64_t AVX512_VP2INTERSECT : 1; // 31 + int64_t AVX512_FP16 : 1; // 32 + int64_t AMX_TILE : 1; // 33 + int64_t AMX_BF16 : 1; // 34 + int64_t AMX_INT8 : 1; // 35 + int64_t AMX_FP16 : 1; // 36 + int64_t AMX_COMPLEX : 1; // 37 + int64_t reserved : (64 - 38); +}; + +class AVX2_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 0; + static constexpr bool AVX512BW = 0; + static constexpr bool AVX512CD = 0; + static constexpr bool AVX512DQ = 0; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 0; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 0; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 0; + static constexpr bool AMX_BF16 = 0; + static constexpr bool AMX_INT8 = 0; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +class AVX512_VNNI_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 1; + static constexpr bool AVX512BW = 1; + static constexpr bool AVX512CD = 1; + static constexpr bool AVX512DQ = 1; + 
static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 1; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 1; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 0; + static constexpr bool AMX_BF16 = 0; + static constexpr bool AMX_INT8 = 0; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +class SapphireRapids { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 1; + static constexpr bool AVX512BW = 1; + static constexpr bool AVX512CD = 1; + static constexpr bool AVX512DQ = 1; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 1; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 1; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 1; + static constexpr bool AMX_BF16 = 1; + static constexpr bool AMX_INT8 = 1; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +template +class isa_base { + public: + static bool constexpr avx = ISA_T >= JblasAVX; + static bool constexpr avx2 = ISA_T >= JblasAVX2; + static bool constexpr avx512f = ISA_T >= JblasAVX512F; + static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; + static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; + static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; + static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; +}; + +class CpuDevice { + public: + inline void setThreads(int _nth) { + if (_nth <= 0) { + numthreads = numcores; + } else { + numthreads = std::min(numcores, _nth); + } + } + inline int getThreads() { return numthreads; } + inline int getCores() { return numcores; } + inline uint32_t getL2CacheSize() { return L2Cache; } + inline uint32_t getL1CacheSize() { return L1Cache; } + inline bool AVX() { return mHasAVX; } + inline bool AVX2() { return mHasAVX2; } + inline bool AVX_VNNI() { return mHasAVX_VNNI; } + inline bool AVX512F() { return mHasAVX512F; } + inline bool AVX512_VNNI() { return mHasAVX512_VNNI; } + inline bool AMX_INT8() { return mHasAMX_INT8; } + inline bool AMX_BF16() { return mHasAMX_BF16; } + inline bool AVX512_BF16() { return mHasAVX512_BF16; } + inline bool AVX512_FP16() { return 
mHasAVX512_FP16; } +#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa) + CpuDevice() { + static Xbyak::util::Cpu _cpu; + L1Cache = _cpu.getDataCacheSize(0); + L2Cache = _cpu.getDataCacheSize(1); + ADD_FLAG(AVX); + ADD_FLAG(AVX2); + ADD_FLAG(AVX512F); + ADD_FLAG(AVX512_VNNI); + ADD_FLAG(AVX_VNNI); + ADD_FLAG(AMX_BF16); + ADD_FLAG(AMX_INT8); + ADD_FLAG(AVX512_BF16); + ADD_FLAG(AVX512_FP16); + numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); + numthreads = numcores; + } + + static CpuDevice* getInstance() { + static CpuDevice instance; + return &instance; + } + + void print() { + printf( + "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n", + mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16, + mHasAVX512_FP16); + } +#undef ADD_FLAG + + protected: + uint32_t L2Cache, L1Cache; + bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, + mHasAVX512_FP16; + int numcores; + int numthreads; +}; + +#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); + +class CpuBase { + public: + CpuBase() { + GetCPUDevice(); + mL2Cache = _cd->getL2CacheSize(); + mL1Cache = _cd->getL1CacheSize(); + mNumThreads = _cd->getThreads(); + } + size_t mL2Cache, mL1Cache; + int mNumThreads; +}; +} // namespace device +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h new file mode 100644 index 0000000000000..ceb7a545092d8 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h @@ -0,0 +1,329 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include + +#include "jit_base.h" +#include "jit_blas.h" +#include "jit_blas_utils.h" +#include "kernel_wrapper.h" + +namespace jblas { +namespace epilogue { +namespace gemm { + +template +class AccumulatorWriteBack { + public: + using SType = _SRC_T; + using DType = _DST_T; + struct Param { + DType* C; + int ldc; + void* elt_const_v; + }; + + template + JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... 
ops) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + bool constexpr Valid = !std::is_same::value || std::is_same::value; + static_assert(Valid, "fp32 to bf16 conversion only."); + if constexpr (std::is_same::value) { + return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward( + const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); + } else if constexpr (std::is_same, std::tuple>::value) { + return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward( + const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); + } else if constexpr (sizeof(SType) == sizeof(DType)) { + return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, + _param.ldc, _param.elt_const_v, ops...); + } else { + assert(false); + } + } +}; + +template +class CustomAccumulatorWriteBackWithEltop { + public: + struct Param { + _DST_T* C; + int ldc; + void* elt_const_v; + }; + JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { + return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep, + _param.ldc, _param.elt_const_v); + } else { + assert(false); + } + } +}; +template +using AccumulatorWriteBackFp32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackInt32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackBf16 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp16 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; + +template +using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; + +template +using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; + +template +class AlphaBetaProcessFp32 { + public: + struct Param { + float *C, *D; + int ldc, ldd; + float alpha, beta; + }; + + JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto DOffset = M_offset * _param.ldd + N_offset; + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + auto dptr = _param.D + DOffset; + return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta, + dptr, _param.ldd, cptr, _param.ldc, M, N); + } +}; + +template +class CompFp32BlockEpilogue { + public: + struct Param { + void* scales; + JBLAS_DTYPE scaledtype; + int ldsb; + int8_t* zps = nullptr; + float* reduce = nullptr; + int ldra; + }; + JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + auto ret = JblasNotSupport; + if (_param.scaledtype == JBLAS_DTYPE::F32) { + ret = kernel::wrapper::CompFp32BlockScale::template forward( + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); + assert(ret == JblasSuccess); + if (_param.zps != nullptr) { + ret = 
kernel::wrapper::RemoveZeroPointBias::forward_wei( + dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, + _param.reduce + M_offset * _param.ldra + K_offset); + } + assert(ret == JblasSuccess); + return ret; + } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { + ret = kernel::wrapper::CompFp32BlockScale::template forward( + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); + assert(_param.zps == nullptr); + assert(ret == JblasSuccess); + return ret; + } + return JblasNotSupport; + } +}; + +template +class DequantInt32ToFp32 { + public: + struct Param { + float* C; + int ldc; + int ldsa; + float* scalesA; + float* scalesB; + }; + JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, + _param.scalesA + M_offset * _param.ldsa, _param.ldsa, + _param.scalesB + N_offset); + } +}; + +template +class CompInt8BlockEpilogue { + public: + struct Param { + void* scalesB; + JBLAS_DTYPE scaleBdtype; + int ldsb; + float* scalesA; + int ldsa; + // optional if A asym + uint8_t* zpA = nullptr; + void* reduceB = nullptr; + JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; + // optional if B asym + int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; + }; + JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + JBLAS_CODE ret = JblasNotSupport; + float* scab = nullptr; + size_t ScaleBTmpSize = N * sizeof(float); + size_t ReduceBTmpSize = N * sizeof(float); + assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); + if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { + auto scache = reinterpret_cast(tmpcache); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, + false); + assert(ret == JblasSuccess); + scab = scache; + } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { + scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; + } + float* redb = nullptr; + if (_param.reduceB) { + if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { + auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, + false); + assert(ret == JblasSuccess); + redb = rcache; + } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { + redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; + } + } + ret = kernel::wrapper::DequanS32Fp32::template forward( + srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); + assert(ret == JblasSuccess); + ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, + dstptr, cachestep, M, N); + assert(ret == JblasSuccess); + + if (_param.zpA == nullptr) { + if (_param.zpB == nullptr) { + return ret; + } else { + 
ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa + K_offset); + } + } else { + if (_param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( + dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, + _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, + _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); + } + } + return ret; + } +}; + +template +class ZpDequantInt32ToFp32 { + public: + struct Param { + // necessary + float* C; + int ldc; + int ldsa; + float* scalesA; + float* scalesB; + // optional if A asym + uint8_t* zpA = nullptr; + float* reduceB = nullptr; + // optional if B asym + int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; + }; + JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, + _param.scalesA + M_offset * _param.ldsa, + _param.ldsa, _param.scalesB + N_offset); + if (ret != JblasSuccess) { + return ret; + } + if (_param.zpA == nullptr && _param.zpB == nullptr) { + return ret; + } else if (_param.zpA != nullptr && _param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, + _param.ldsa, _param.reduceB + N_offset); + } else if (_param.zpA == nullptr && _param.zpB != nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, + _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, + _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); + } + return ret; + } +}; + +template +class AlphaBetaProcessS32U8 { + public: + struct Param { + uint8_t* C; + int ldc; + float alpha; + float scaleAcc, scaleC; + int zpC; + }; + + JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, + M, N, _param.scaleAcc, _param.scaleC, _param.zpC); + } +}; + +} // namespace gemm +} // namespace epilogue +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h new file mode 100644 index 0000000000000..364da9223940f --- /dev/null +++ 
b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h @@ -0,0 +1,2699 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include + +#include "jit_blas_utils.h" +#include "jit_base.h" + +namespace jblas { +namespace gemm { +enum class CompType : uint32_t { + COMP_FP32 = 0, + COMP_BF16_FP32 = 1, + COMP_FP16_FP16 = 2, + COMP_INT_START = 3, + COMP_INT8_US_INT32 = COMP_INT_START, + COMP_INT8_UU_INT32 = 4, + COMP_INT8_SS_INT32 = 5, + COMP_INT8_SU_INT32 = 6, + COMP_INT16_SS_INT32 = 7, + COMP_INT8_US_FP32 = 8, + COMP_INT8_UU_FP32 = 9, + COMP_INT8_SS_FP32 = 10, + COMP_INT8_SU_FP32 = 11, +}; + +class CoreAttr { + public: + // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| + static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, + COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; + + static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } + static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) { + return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); + } + + static void parse_id(uint32_t id, uint32_t* vals) { + vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); + vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); + vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); + } + + static const char* to_str(uint32_t id) { + static char tmp[128]; + uint32_t vals[4]; + parse_id(id, vals); + sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); + return tmp; + } + + static inline size_t get_bsize(uint32_t id) { + auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + return size_t(4 / packrow); + } +}; + +namespace code { + +template +class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { + public: + static int constexpr RegLen = 8, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + 
add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * 
BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { + public: + static int constexpr RegLen = 32, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16; + typedef utils::fp16 AType; + typedef utils::fp16 BType; + typedef utils::fp16 CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + 
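+ // Remainder k-loop: advance B by one k step, BKStepSize = KTILE * NTILE * sizeof(fp16); e.g. with _NTILE = 64 (NRegs = 2) this is 1 * 64 * 2 = 128 bytes (illustrative instantiation).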
add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + 
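+ // Advance B by one bf16 k step: KTILE = 2 because vdpbf16ps consumes one bf16 pair per fp32 lane (PackRow = 2), so BKStepSize = 2 * NTILE * sizeof(bf16).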
add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + 
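+ // KTILE = 4: vpdpbusds packs four u8*s8 products per dword lane (PackRow = 4), so the k counter advances in multiples of 4 until it reaches ksize.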
cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + 
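+ // Same remainder loop on 256-bit ymm registers (RegLen = 8 dword lanes, NTILE = 8 * NRegs); the k counter below still advances by the packed KTILE of 4 int8.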
add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + 
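+ // Remainder loop: one AMX tile step per iteration, consuming KTILE = 32 bf16 values along k per tdpbf16ps.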
generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto& reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + } + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int 
m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + } + outLocalLabel(); + } +}; + +template +class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8; + static uint32_t constexpr COMPUTE = + (uint32_t)(std::is_same_v + ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 + : std::is_same_v ? CompType::COMP_INT8_US_INT32 + : CompType::COMP_INT8_UU_INT32); + using AType = AT; + using BType = BT; + typedef int32_t CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, 
ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto& reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), 
ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + } + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + } + outLocalLabel(); + } +}; +template +using Amxint8N16P4US = Amxint8N16P4; + +template +using Amxint8N16P4SS = Amxint8N16P4; + +class AmxConfigure : protected jblas::xbyak::JitAmxtile { + public: + typedef long long (*func_t)(tileconfig_t*); + + static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { + static AmxConfigure code; + tileconfig_t cfg; + std::memset(&cfg, 0, sizeof(cfg)); + configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); + code.mKernel(&cfg); + } + + protected: + AmxConfigure() { + generate_config(this); + mKernel = getCode(); + } + + func_t mKernel = nullptr; +}; + +namespace kblock { +// optimize for kblock gemm, each block size in k dimension has dequant operation +// all accumulators use fp32 dtype. +template +class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * 
BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32; + typedef uint8_t AType; + typedef int8_t BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = NRegs; + CReg = 0; + CF32Reg = CReg + CRegCount; + BReg = CF32Reg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg < RegCount); + TmpRegCount = RegCount - TmpReg; + assert(TmpRegCount >= 1); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); + } + } + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, 
reg_tmp3); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { + for (int kk = 0; kk < _ktile; kk++) { + lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[tmp]); + add(tmp, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + // Dequantize this kblock's int32 accumulators: CF32 += float(C) * scaleA[m] * scaleB[n]. + void generate_f32_accumulate(int _mtile) { + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); + vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); + vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + } + } + } + + // Zero-point correction for asymmetric A: CF32 -= zpA[m] * scaleA[m] * reduceB[n], where reduceB holds the per-column sums of B over this kblock. + void generate_zp_correction(int _mtile) { + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * 
sizeof(AType)]); + auto& reg_zpA = reg_tmp; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); + } + + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); + for (int j = 0; j < NRegs; j++) { + vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); + vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +} // namespace kblock +} // namespace code +template