diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index f827f3bc95456..343186b1aec8c 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -53,14 +53,11 @@ jobs: cp tools/ci_build/github/js/react_native_e2e_full_aar_build_settings.json ${{ runner.temp }}/.build_settings/build_settings.json python3 -m pip install --user -r ${{ github.workspace }}/tools/ci_build/requirements/pybind/requirements.txt - - python3 ${{ github.workspace }}/tools/ci_build/github/android/build_aar_package.py --build_dir ${{ runner.temp }} --config Release --android_sdk_path $ANDROID_SDK_ROOT --android_ndk_path $ANDROID_NDK_ROOT ${{ runner.temp }}/.build_settings/build_settings.json + + python3 ${{ github.workspace }}/tools/ci_build/github/android/build_aar_package.py --build_dir ${{ runner.temp }} --config Release --android_sdk_path $ANDROID_SDK_ROOT --android_ndk_path $ANDROID_NDK_ROOT ${{ runner.temp }}/.build_settings/build_settings.json # Copy the built artifacts to give folder for publishing - BASE_PATH=${{ runner.temp }}/aar_out/Release/com/microsoft/onnxruntime/onnxruntime-android/${OnnxRuntimeVersion} - cp ${BASE_PATH}/*.jar ${{ runner.temp }}/artifacts - cp ${BASE_PATH}/*.aar ${{ runner.temp }}/artifacts - cp ${BASE_PATH}/*.pom ${{ runner.temp }}/artifacts + cp -r ${{ runner.temp }}/aar_out/Release/com ${{ runner.temp }}/artifacts - name: Upload Android AAR Artifact uses: actions/upload-artifact@v5 @@ -109,10 +106,8 @@ jobs: - name: Copy AAR to React Native and E2E directories run: | - mkdir -p ${{ github.workspace }}/js/react_native/android/libs - cp ${{ runner.temp }}/android-full-aar/*.aar ${{ github.workspace }}/js/react_native/android/libs mkdir -p ${{ github.workspace }}/js/react_native/e2e/android/app/libs - cp ${{ runner.temp }}/android-full-aar/*.aar ${{ github.workspace }}/js/react_native/e2e/android/app/libs + cp -r ${{ runner.temp }}/android-full-aar/com ${{ github.workspace }}/js/react_native/e2e/android/app/libs - name: Install dependencies and bootstrap run: | @@ -141,10 +136,6 @@ jobs: with: ndk-version: 28.0.13004108 - - name: Run React Native Android Instrumented Tests - run: ./gradlew connectedDebugAndroidTest --stacktrace - working-directory: ${{ github.workspace }}/js/react_native/android - - name: Run React Native Detox Android e2e Tests run: | JEST_JUNIT_OUTPUT_FILE=${{ github.workspace }}/js/react_native/e2e/android-test-results.xml \ @@ -169,6 +160,15 @@ jobs: echo "Emulator PID file was expected to exist but does not." 
fi + - name: Upload Android Test Results + if: always() + uses: actions/upload-artifact@v5 + with: + name: android-test-results + path: | + ${{ github.workspace }}/js/react_native/e2e/android-test-results.xml + ${{ github.workspace }}/js/react_native/e2e/artifacts + react_native_ci_ios_build: name: React Native CI iOS Build runs-on: macos-14 @@ -211,62 +211,6 @@ jobs: name: ios_pod path: ${{ runner.temp }}/ios_pod - react_native_ci_ios_unit_tests: - name: React Native CI iOS Unit Tests - needs: react_native_ci_ios_build - runs-on: macos-14 - timeout-minutes: 90 - steps: - - name: Checkout repository - uses: actions/checkout@v5 - - - name: Download iOS pod artifact - uses: actions/download-artifact@v6 - with: - name: ios_pod - path: ${{ runner.temp }}/ios_pod - - - name: Use Xcode 15.3.0 - run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer - - - name: Use Node.js 22.x - uses: actions/setup-node@v6 - with: - node-version: '22.x' - - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 - with: - vcpkg-version: '2025.06.13' - vcpkg-hash: 735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc - cmake-version: '3.31.8' - cmake-hash: 99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8 - add-cmake-to-path: 'true' - disable-terrapin: 'true' - - - name: Install dependencies and bootstrap - run: | - npm ci - working-directory: ${{ github.workspace }}/js - - run: npm ci - working-directory: ${{ github.workspace }}/js/common - - run: | - set -e -x - npm ci - npm run bootstrap-no-pods - working-directory: ${{ github.workspace }}/js/react_native - - - name: Pod install - run: | - set -e -x - ls ${{ runner.temp }}/ios_pod/onnxruntime-c - ORT_C_LOCAL_POD_PATH=${{ runner.temp }}/ios_pod/onnxruntime-c pod install --verbose - working-directory: ${{ github.workspace }}/js/react_native/ios - - - name: Run React Native iOS Instrumented Tests - run: | - /usr/bin/xcodebuild -sdk iphonesimulator -configuration Debug -workspace ${{ github.workspace }}/js/react_native/ios/OnnxruntimeModule.xcworkspace -scheme OnnxruntimeModuleTest -destination 'platform=iOS Simulator,name=iPhone 15,OS=17.4' test CODE_SIGNING_ALLOWED=NO - working-directory: ${{ github.workspace }}/js/react_native/ios - react_native_ci_ios_e2e_tests: name: React Native CI iOS E2E Tests needs: react_native_ci_ios_build @@ -314,7 +258,7 @@ jobs: npm ci npm run bootstrap-no-pods working-directory: ${{ github.workspace }}/js/react_native - + - name: Pod install for e2e tests run: | set -e -x @@ -331,3 +275,12 @@ jobs: --loglevel verbose \ --take-screenshots failing working-directory: ${{ github.workspace }}/js/react_native/e2e + + - name: Upload iOS Test Results + if: always() + uses: actions/upload-artifact@v5 + with: + name: ios-test-results + path: | + ${{ github.workspace }}/js/react_native/e2e/ios-test-results.xml + ${{ github.workspace }}/js/react_native/e2e/artifacts diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index d81c5d559c8e5..d20778d56f60b 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -61,7 +61,7 @@ jobs: working-directory: ${{ github.workspace }} - name: Use .NET 8.x - uses: actions/setup-dotnet@v5 + uses: actions/setup-dotnet@v3 with: dotnet-version: '8.x' env: diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7419decf8946f..4cbcb4d1e9d60 100644 --- 
a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -893,7 +893,9 @@ if (onnxruntime_USE_QNN OR onnxruntime_USE_QNN_INTERFACE) if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc" OR ${QNN_ARCH_ABI} STREQUAL "arm64x-windows-msvc") file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so" - "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat") + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat" + "${onnxruntime_QNN_HOME}/lib/hexagon-v81/unsigned/libQnnHtpV81Skel.so" + "${onnxruntime_QNN_HOME}/lib/hexagon-v81/unsigned/libqnnhtpv81.cat") list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB}) endif() message(STATUS "QNN lib files: " ${QNN_LIB_FILES}) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 434aa075e62d6..02915f2f1882e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3937,6 +3937,7 @@ struct OrtApi { * -# "69" * -# "73" * -# "75" + * -# "81" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). * "enable_htp_fp16_precision": Used for float32 model for HTP backend. * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 09316966a2fd1..9503127006966 100--- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -258,6 +258,20 @@ export declare namespace InferenceSession { */ forceCpuNodeNames?: readonly string[]; + /** + * Specify the validation mode for the WebGPU execution provider. + * - 'disabled': Disable all validation. + * When used in Node.js, disabling validation may cause the process to crash if WebGPU errors occur. Be cautious when using + * this mode. + * When used on the web, this mode is equivalent to 'wgpuOnly'. + * - 'wgpuOnly': Perform WebGPU internal validation only. + * - 'basic': Perform basic validation, including WebGPU internal validation. This is the default mode. + * - 'full': Perform full validation. This mode may have a performance impact. Use it for debugging purposes. + * + * @default 'basic' + */ + validationMode?: 'disabled' | 'wgpuOnly' | 'basic' | 'full'; + /** + * Specify an optional WebGPU device to be used by the WebGPU execution provider.
*/ diff --git a/js/package-lock.json b/js/package-lock.json index a13f1ae373f4b..1e9f5cb29fe6c 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4020,9 +4020,9 @@ } }, "node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "dependencies": { "argparse": "^2.0.1" @@ -8555,9 +8555,9 @@ } }, "js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "requires": { "argparse": "^2.0.1" diff --git a/js/react_native/android/CMakeLists.txt b/js/react_native/android/CMakeLists.txt index 98f30daac6372..2f814e871ad77 100644 --- a/js/react_native/android/CMakeLists.txt +++ b/js/react_native/android/CMakeLists.txt @@ -1,37 +1,99 @@ -project(OnnxruntimeJSIHelper) +project(OnnxruntimeJSI) cmake_minimum_required(VERSION 3.9.0) -set (PACKAGE_NAME "onnxruntime-react-native") -set (BUILD_DIR ${CMAKE_SOURCE_DIR}/build) +set(PACKAGE_NAME "onnxruntime-react-native") +set(BUILD_DIR ${CMAKE_SOURCE_DIR}/build) set(CMAKE_VERBOSE_MAKEFILE ON) set(CMAKE_CXX_STANDARD 17) -file(TO_CMAKE_PATH "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath) +option(ORT_EXTENSIONS_ENABLED "Enable Ort Extensions" NO) +option(USE_NNAPI "Use NNAPI" YES) +option(USE_QNN "Use QNN" NO) + +file(GLOB libfbjni_link_DIRS "${BUILD_DIR}/fbjni-*.aar/jni/${ANDROID_ABI}") +file(GLOB libfbjni_include_DIRS "${BUILD_DIR}/fbjni-*-headers.jar/") + +file(GLOB onnxruntime_include_DIRS + "${BUILD_DIR}/onnxruntime-android-*.aar/headers") +file(GLOB onnxruntime_link_DIRS + "${BUILD_DIR}/onnxruntime-android-*.aar/jni/${ANDROID_ABI}/") + +if(ORT_EXTENSIONS_ENABLED) + add_definitions(-DORT_ENABLE_EXTENSIONS) +endif() + +if(USE_QNN) + add_definitions(-DUSE_QNN) +endif() + +if(USE_NNAPI) + add_definitions(-DUSE_NNAPI) +endif() + +file(TO_CMAKE_PATH + "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath) + +find_package(fbjni REQUIRED CONFIG) +find_package(ReactAndroid REQUIRED CONFIG) + +find_library( + onnxruntime-lib onnxruntime + PATHS ${onnxruntime_link_DIRS} + NO_CMAKE_FIND_ROOT_PATH) + +set(RN_INCLUDES + "${NODE_MODULES_DIR}/react-native/React" + "${NODE_MODULES_DIR}/react-native/React/Base" + "${NODE_MODULES_DIR}/react-native/ReactCommon" + "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi" + "${NODE_MODULES_DIR}/react-native/ReactCommon/callinvoker") + +if(${REACT_NATIVE_VERSION} VERSION_GREATER_EQUAL "0.76") + set(RN_LIBS + ReactAndroid::reactnative + ReactAndroid::jsi) +else() + list( + APPEND + RN_INCLUDES + "${NODE_MODULES_DIR}/react-native/ReactAndroid/src/main/java/com/facebook/react/turbomodule/core/jni" + ) + set(RN_LIBS + ReactAndroid::jsi + ReactAndroid::react_nativemodule_core + ReactAndroid::turbomodulejsijni) +endif() include_directories( - "${NODE_MODULES_DIR}/react-native/React" - 
"${NODE_MODULES_DIR}/react-native/React/Base" - "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi" -) + ../cpp + ${RN_INCLUDES} + ${onnxruntime_include_DIRS} + ${libfbjni_include_DIRS}) -add_library(onnxruntimejsihelper - SHARED - ${libPath} - src/main/cpp/cpp-adapter.cpp -) +add_library( + onnxruntimejsi SHARED + ${libPath} + src/main/cpp/cpp-adapter.cpp + ../cpp/JsiMain.cpp + ../cpp/InferenceSessionHostObject.cpp + ../cpp/JsiUtils.cpp + ../cpp/SessionUtils.cpp + ../cpp/TensorUtils.cpp) # Configure C++ 17 set_target_properties( - onnxruntimejsihelper PROPERTIES - CXX_STANDARD 17 - CXX_EXTENSIONS OFF - POSITION_INDEPENDENT_CODE ON -) + onnxruntimejsi + PROPERTIES CXX_STANDARD 17 + CXX_EXTENSIONS OFF + POSITION_INDEPENDENT_CODE ON) find_library(log-lib log) target_link_libraries( - onnxruntimejsihelper - ${log-lib} # <-- Logcat logger - android # <-- Android JNI core + onnxruntimejsi + ${onnxruntime-lib} + fbjni::fbjni + ${RN_LIBS} + ${log-lib} # <-- Logcat logger + android # <-- Android JNI core ) diff --git a/js/react_native/android/build.gradle b/js/react_native/android/build.gradle index 2f5b5adc7a1fa..41b43599a9af6 100644 --- a/js/react_native/android/build.gradle +++ b/js/react_native/android/build.gradle @@ -48,23 +48,22 @@ static def findNodeModules(baseDir) { def nodeModules = findNodeModules(projectDir); -def checkIfOrtExtensionsEnabled() { +def readPackageJsonField(field) { // locate user's project dir def reactnativeRootDir = project.rootDir.parentFile // get package.json file in root directory def packageJsonFile = new File(reactnativeRootDir, 'package.json') - // read field 'onnxruntimeExtensionsEnabled' if (packageJsonFile.exists()) { def packageJsonContents = packageJsonFile.getText() def packageJson = new groovy.json.JsonSlurper().parseText(packageJsonContents) - return packageJson.onnxruntimeExtensionsEnabled == "true" + return packageJson.get(field) } else { - logger.warn("Could not find package.json file in the expected directory: ${reactnativeRootDir}. ONNX Runtime Extensions will not be enabled.") + logger.warn("Could not find package.json file in the expected directory: ${reactnativeRootDir}. 
${field} will not be enabled.") } - return false } -boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled() +boolean ortExtensionsEnabled = readPackageJsonField('onnxruntimeExtensionsEnabled') == "true" +boolean useQnn = readPackageJsonField('onnxruntimeUseQnn') == "true" def REACT_NATIVE_VERSION = ['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim() def REACT_NATIVE_MINOR_VERSION = REACT_NATIVE_VERSION.split("\\.")[1].toInteger() @@ -85,9 +84,18 @@ android { cppFlags "-O2 -frtti -fexceptions -Wall -Wno-unused-variable -fstack-protector-all" if (REACT_NATIVE_MINOR_VERSION >= 71) { // fabricjni required c++_shared - arguments "-DANDROID_STL=c++_shared", "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}" + arguments "-DANDROID_STL=c++_shared", + "-DNODE_MODULES_DIR=${nodeModules}", + "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}", + "-DREACT_NATIVE_VERSION=${REACT_NATIVE_VERSION}", + "-DUSE_QNN=${useQnn}", + "-DUSE_NNAPI=${!useQnn}" } else { - arguments "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}" + arguments "-DNODE_MODULES_DIR=${nodeModules}", + "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}", + "-DREACT_NATIVE_VERSION=${REACT_NATIVE_VERSION}", + "-DUSE_QNN=${useQnn}", + "-DUSE_NNAPI=${!useQnn}" } abiFilters (*reactNativeArchitectures()) } @@ -119,6 +127,9 @@ android { "META-INF", "META-INF/**", "**/libjsi.so", + "**/libfbjni.so", + "**/libreact_nativemodule_core.so", + "**/libturbomodulejsijni.so" ] } @@ -147,6 +158,10 @@ android { } } } + + configurations { + extractLibs + } } repositories { @@ -217,10 +232,6 @@ repositories { "Ensure you have you installed React Native as a dependency in your project and try again." 
) } - - flatDir { - dir 'libs' - } } dependencies { @@ -228,16 +239,47 @@ dependencies { implementation "com.facebook.react:react-android:"+ REACT_NATIVE_VERSION api "org.mockito:mockito-core:2.28.2" - androidTestImplementation "androidx.test:runner:1.5.2" - androidTestImplementation "androidx.test:rules:1.5.0" implementation "junit:junit:4.12" - androidTestImplementation "com.linkedin.dexmaker:dexmaker-mockito-inline-extended:2.28.1" + if (useQnn) { + extractLibs "com.microsoft.onnxruntime:onnxruntime-android-qnn:latest.integration@aar" + } else { + extractLibs "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar" + } - implementation "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar" + if (VersionNumber.parse(REACT_NATIVE_VERSION) < VersionNumber.parse("0.71")) { + extractLibs "com.facebook.fbjni:fbjni:+:headers" + extractLibs "com.facebook.fbjni:fbjni:+" + } // By default it will just include onnxruntime full aar package if (ortExtensionsEnabled) { implementation "com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.integration@aar" } -} \ No newline at end of file +} + +task extractLibs { + doLast { + configurations.extractLibs.files.each { + def file = it.absoluteFile + copy { + from zipTree(file) + into "$buildDir/$file.name" + include "**/*.h", "**/*.so" + } + } + } +} + +def nativeBuildDependsOn(dependsOnTask, variant) { + def buildTasks = tasks.findAll({ task -> + !task.name.contains("Clean") && (task.name.contains("externalNative") || task.name.contains("CMake")) }) + if (variant != null) { + buildTasks = buildTasks.findAll({ task -> task.name.contains(variant) }) + } + buildTasks.forEach { task -> task.dependsOn(dependsOnTask) } +} + +afterEvaluate { + nativeBuildDependsOn(extractLibs, null) +} diff --git a/js/react_native/android/src/androidTest/Readme.md b/js/react_native/android/src/androidTest/Readme.md deleted file mode 100644 index b0376602af908..0000000000000 --- a/js/react_native/android/src/androidTest/Readme.md +++ /dev/null @@ -1 +0,0 @@ -Please see [here](/js/react_native/test_types_models.readme.md) for information on the test models. diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java deleted file mode 100644 index 82d063ad51e3f..0000000000000 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -package ai.onnxruntime.reactnative; - -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.JavaOnlyMap; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.modules.blob.BlobModule; - -public class FakeBlobModule extends BlobModule { - - public FakeBlobModule(ReactApplicationContext context) { super(null); } - - @Override - public String getName() { - return "BlobModule"; - } - - public JavaOnlyMap testCreateData(byte[] bytes) { - String blobId = store(bytes); - JavaOnlyMap data = new JavaOnlyMap(); - data.putString("blobId", blobId); - data.putInt("offset", 0); - data.putInt("size", bytes.length); - return data; - } - - public byte[] testGetData(ReadableMap data) { - String blobId = data.getString("blobId"); - int offset = data.getInt("offset"); - int size = data.getInt("size"); - return resolve(blobId, offset, size); - } -} diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java deleted file mode 100644 index b15b1a468ae29..0000000000000 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package ai.onnxruntime.reactnative; - -import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession; -import static org.mockito.Mockito.when; - -import ai.onnxruntime.TensorInfo; -import android.util.Base64; -import androidx.test.platform.app.InstrumentationRegistry; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.CatalystInstance; -import com.facebook.react.bridge.JavaOnlyArray; -import com.facebook.react.bridge.JavaOnlyMap; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReadableArray; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.bridge.WritableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.MockitoSession; - -public class OnnxruntimeModuleTest { - private ReactApplicationContext reactContext = - new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext()); - - private FakeBlobModule blobModule; - - private static byte[] getInputModelBuffer(InputStream modelStream) throws Exception { - ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); - - int bufferSize = 1024; - byte[] buffer = new byte[bufferSize]; - - int len; - while ((len = modelStream.read(buffer)) != -1) { - byteBuffer.write(buffer, 0, len); - } - - byte[] modelBuffer = byteBuffer.toByteArray(); - - return modelBuffer; - } - - @Before - public void setUp() { - blobModule = new FakeBlobModule(reactContext); - } - - @Test - public void getName() throws Exception { - OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); - ortModule.blobModule = blobModule; - String name = "Onnxruntime"; - Assert.assertEquals(ortModule.getName(), name); - } - - @Test - public void onnxruntime_module() throws Exception { - MockitoSession mockSession = 
mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); - ortModule.blobModule = blobModule; - String sessionKey = ""; - - // test loadModel() - { - try (InputStream modelStream = - reactContext.getResources().openRawResource(ai.onnxruntime.reactnative.test.R.raw.test_types_float);) { - byte[] modelBuffer = getInputModelBuffer(modelStream); - - JavaOnlyMap options = new JavaOnlyMap(); - try { - ReadableMap resultMap = ortModule.loadModel(modelBuffer, options); - sessionKey = resultMap.getString("key"); - ReadableArray inputNames = resultMap.getArray("inputNames"); - ReadableArray outputNames = resultMap.getArray("outputNames"); - - Assert.assertEquals(inputNames.size(), 1); - Assert.assertEquals(inputNames.getString(0), "input"); - Assert.assertEquals(outputNames.size(), 1); - Assert.assertEquals(outputNames.getString(0), "output"); - } catch (Exception e) { - Assert.fail(e.getMessage()); - } - } - } - - int[] dims = new int[] {1, 5}; - float[] inputData = new float[] {1.0f, 2.0f, -3.0f, Float.MIN_VALUE, Float.MAX_VALUE}; - - // test run() - { - JavaOnlyMap inputDataMap = new JavaOnlyMap(); - { - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dimsArray = new JavaOnlyArray(); - for (int dim : dims) { - dimsArray.pushInt(dim); - } - inputTensorMap.putArray("dims", dimsArray); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeFloat); - - ByteBuffer buffer = ByteBuffer.allocate(5 * Float.BYTES).order(ByteOrder.nativeOrder()); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (float value : inputData) { - floatBuffer.put(value); - } - floatBuffer.rewind(); - inputTensorMap.putMap("data", blobModule.testCreateData(buffer.array())); - - inputDataMap.putMap("input", inputTensorMap); - } - - JavaOnlyArray outputNames = new JavaOnlyArray(); - outputNames.pushString("output"); - - JavaOnlyMap options = new JavaOnlyMap(); - options.putBoolean("encodeTensorData", true); - - try { - ReadableMap resultMap = ortModule.run(sessionKey, inputDataMap, outputNames, options); - - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat); - ReadableMap data = outputMap.getMap("data"); - FloatBuffer buffer = - ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); - } - } catch (Exception e) { - Assert.fail(e.getMessage()); - } - } - - // test dispose - ortModule.dispose(sessionKey); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void onnxruntime_module_append_nnapi() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); - ortModule.blobModule = blobModule; - String sessionKey = ""; - - // test loadModel() with nnapi ep options - - try (InputStream modelStream = - 
reactContext.getResources().openRawResource(ai.onnxruntime.reactnative.test.R.raw.test_types_float);) { - - byte[] modelBuffer = getInputModelBuffer(modelStream); - - // register with nnapi ep options - JavaOnlyMap options = new JavaOnlyMap(); - JavaOnlyArray epArray = new JavaOnlyArray(); - epArray.pushString("nnapi"); - options.putArray("executionProviders", epArray); - - try { - ReadableMap resultMap = ortModule.loadModel(modelBuffer, options); - sessionKey = resultMap.getString("key"); - ReadableArray inputNames = resultMap.getArray("inputNames"); - ReadableArray outputNames = resultMap.getArray("outputNames"); - - Assert.assertEquals(inputNames.size(), 1); - Assert.assertEquals(inputNames.getString(0), "input"); - Assert.assertEquals(outputNames.size(), 1); - Assert.assertEquals(outputNames.getString(0), "output"); - } catch (Exception e) { - Assert.fail(e.getMessage()); - } - } - ortModule.dispose(sessionKey); - } finally { - mockSession.finishMocking(); - } - } -} diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java deleted file mode 100644 index 72518488e6682..0000000000000 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ /dev/null @@ -1,565 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package ai.onnxruntime.reactnative; - -import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession; -import static org.mockito.Mockito.when; - -import ai.onnxruntime.OnnxJavaType; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtSession; -import ai.onnxruntime.OrtUtil; -import ai.onnxruntime.TensorInfo; -import android.content.Context; -import android.util.Base64; -import androidx.test.filters.SmallTest; -import androidx.test.platform.app.InstrumentationRegistry; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.JavaOnlyArray; -import com.facebook.react.bridge.JavaOnlyMap; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import java.nio.ShortBuffer; -import java.util.HashMap; -import java.util.Map; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.MockitoSession; - -@SmallTest -public class TensorHelperTest { - private ReactApplicationContext reactContext = - new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext()); - - private OrtEnvironment ortEnvironment; - - private FakeBlobModule blobModule; - - @Before - public void setUp() { - ortEnvironment = OrtEnvironment.getEnvironment("TensorHelperTest"); - blobModule = new FakeBlobModule(reactContext); - } - - @Test - public void createInputTensor_float32() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new float[] {Float.MIN_VALUE, 2.0f, Float.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - 
inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeFloat); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 4).order(ByteOrder.nativeOrder()); - FloatBuffer dataFloatBuffer = dataByteBuffer.asFloatBuffer(); - dataFloatBuffer.put(Float.MIN_VALUE); - dataFloatBuffer.put(2.0f); - dataFloatBuffer.put(Float.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getFloatBuffer().array(), outputTensor.getFloatBuffer().array(), 1e-6f); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_int8() throws Exception { - OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, new byte[] {Byte.MIN_VALUE, 2, Byte.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeByte); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3); - dataByteBuffer.put(Byte.MIN_VALUE); - dataByteBuffer.put((byte)2); - dataByteBuffer.put(Byte.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_uint8() throws Exception { - OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, ByteBuffer.wrap(new byte[] {0, 2, (byte)255}), - new long[] {3}, OnnxJavaType.UINT8); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeUnsignedByte); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3); - dataByteBuffer.put((byte)0); - dataByteBuffer.put((byte)2); - dataByteBuffer.put((byte)255); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public 
void createInputTensor_int32() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new int[] {Integer.MIN_VALUE, 2, Integer.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeInt); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 4).order(ByteOrder.nativeOrder()); - IntBuffer dataIntBuffer = dataByteBuffer.asIntBuffer(); - dataIntBuffer.put(Integer.MIN_VALUE); - dataIntBuffer.put(2); - dataIntBuffer.put(Integer.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getIntBuffer().array(), outputTensor.getIntBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_int64() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new long[] {Long.MIN_VALUE, 15000000001L, Long.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeLong); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 8).order(ByteOrder.nativeOrder()); - LongBuffer dataLongBuffer = dataByteBuffer.asLongBuffer(); - dataLongBuffer.put(Long.MIN_VALUE); - dataLongBuffer.put(15000000001L); - dataLongBuffer.put(Long.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getLongBuffer().array(), outputTensor.getLongBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_double() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new double[] {Double.MIN_VALUE, 1.8e+30, Double.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeDouble); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 8).order(ByteOrder.nativeOrder()); - DoubleBuffer dataDoubleBuffer = dataByteBuffer.asDoubleBuffer(); - dataDoubleBuffer.put(Double.MIN_VALUE); - dataDoubleBuffer.put(1.8e+30); - dataDoubleBuffer.put(Double.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - 
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); - Assert.assertEquals(outputTensor.getInfo().onnxType, - TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getDoubleBuffer().array(), outputTensor.getDoubleBuffer().array(), 1e-6f); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_bool() throws Exception { - OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, new boolean[] {false, true}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(2); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeBool); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(2); - dataByteBuffer.put((byte)0); - dataByteBuffer.put((byte)1); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createOutputTensor_bool() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_bool); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - boolean[] inputData = new boolean[] {true, false, false, true, false}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeBool); - ReadableMap data = outputMap.getMap("data"); - ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i) == 1, inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_double() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new 
JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_double); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - double[] inputData = new double[] {1.0f, 2.0f, -3.0f, Double.MIN_VALUE, Double.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeDouble); - ReadableMap data = outputMap.getMap("data"); - DoubleBuffer buffer = - ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asDoubleBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_float() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_float); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - float[] inputData = new float[] {1.0f, 2.0f, -3.0f, Float.MIN_VALUE, Float.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat); - ReadableMap data = outputMap.getMap("data"); - FloatBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_int8() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = 
readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_int8); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - byte[] inputData = new byte[] {1, 2, -3, Byte.MAX_VALUE, Byte.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeByte); - ReadableMap data = outputMap.getMap("data"); - ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_int32() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_int32); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - int[] inputData = new int[] {1, 2, -3, Integer.MIN_VALUE, Integer.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeInt); - ReadableMap data = outputMap.getMap("data"); - IntBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asIntBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_int64() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_int64); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - long[] inputData = new long[] {1, 2, -3, Long.MIN_VALUE, Long.MAX_VALUE}; - - 
String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeLong); - ReadableMap data = outputMap.getMap("data"); - LongBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asLongBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_uint8() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_uint8); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - byte[] inputData = new byte[] {1, 2, -3, Byte.MAX_VALUE, Byte.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - ByteBuffer inputBuffer = ByteBuffer.wrap(inputData); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, inputBuffer, dims, OnnxJavaType.UINT8); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeUnsignedByte); - ReadableMap data = outputMap.getMap("data"); - ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - private byte[] readBytesFromResourceFile(int resourceId) throws Exception { - Context context = InstrumentationRegistry.getInstrumentation().getContext(); - InputStream inputStream = context.getResources().openRawResource(resourceId); - ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); - - int bufferSize = 1024; - byte[] buffer = new byte[bufferSize]; - - int len; - while ((len = inputStream.read(buffer)) != -1) { - byteBuffer.write(buffer, 0, len); - } - - return byteBuffer.toByteArray(); - } -} diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_float.ort b/js/react_native/android/src/androidTest/res/raw/test_types_float.ort deleted file mode 100644 index e5c40742843d5..0000000000000 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_float.ort and /dev/null differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort 
b/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort deleted file mode 100644 index 6135c9a4aca7c..0000000000000 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort and /dev/null differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort deleted file mode 100644 index a9892d9ec598d..0000000000000 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort and /dev/null differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort deleted file mode 100644 index f1bf199e488e1..0000000000000 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort and /dev/null differ diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort b/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort deleted file mode 100644 index 9f5310803323a..0000000000000 Binary files a/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort and /dev/null differ diff --git a/js/react_native/android/src/main/cpp/cpp-adapter.cpp b/js/react_native/android/src/main/cpp/cpp-adapter.cpp index d75a2f9c99d8b..50434b71ec2ed 100644 --- a/js/react_native/android/src/main/cpp/cpp-adapter.cpp +++ b/js/react_native/android/src/main/cpp/cpp-adapter.cpp @@ -1,126 +1,43 @@ +#include "JsiMain.h" +#include +#include +#include +#include #include #include -#include using namespace facebook; -typedef u_int8_t byte; - -std::string jstring2string(JNIEnv* env, jstring jStr) { - if (!jStr) return ""; - - jclass stringClass = env->GetObjectClass(jStr); - jmethodID getBytes = env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B"); - const auto stringJbytes = (jbyteArray)env->CallObjectMethod(jStr, getBytes, env->NewStringUTF("UTF-8")); - - auto length = (size_t)env->GetArrayLength(stringJbytes); - jbyte* pBytes = env->GetByteArrayElements(stringJbytes, nullptr); - - std::string ret = std::string((char*)pBytes, length); - env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT); - - env->DeleteLocalRef(stringJbytes); - env->DeleteLocalRef(stringClass); - return ret; -} - -byte* getBytesFromBlob(JNIEnv* env, jobject instanceGlobal, const std::string& blobId, int offset, int size) { - if (!env) throw std::runtime_error("JNI Environment is gone!"); - - // get java class - jclass clazz = env->GetObjectClass(instanceGlobal); - // get method in java class - jmethodID getBufferJava = env->GetMethodID(clazz, "getBlobBuffer", "(Ljava/lang/String;II)[B"); - // call method - auto jstring = env->NewStringUTF(blobId.c_str()); - auto boxedBytes = (jbyteArray)env->CallObjectMethod(instanceGlobal, - getBufferJava, - // arguments - jstring, - offset, - size); - env->DeleteLocalRef(jstring); - - jboolean isCopy = true; - jbyte* bytes = env->GetByteArrayElements(boxedBytes, &isCopy); - env->DeleteLocalRef(boxedBytes); - return reinterpret_cast(bytes); -}; - -std::string createBlob(JNIEnv* env, jobject instanceGlobal, byte* bytes, size_t size) { - if (!env) throw std::runtime_error("JNI Environment is gone!"); - - // get java class - jclass clazz = env->GetObjectClass(instanceGlobal); - // get method in java class - jmethodID getBufferJava = env->GetMethodID(clazz, "createBlob", "([B)Ljava/lang/String;"); - // call method - auto byteArray = env->NewByteArray(size); - env->SetByteArrayRegion(byteArray, 0, size, 
reinterpret_cast(bytes)); - auto blobId = (jstring)env->CallObjectMethod(instanceGlobal, getBufferJava, byteArray); - env->DeleteLocalRef(byteArray); - - return jstring2string(env, blobId); +static std::shared_ptr env; + +class OnnxruntimeModule + : public jni::JavaClass { + public: + static constexpr auto kJavaDescriptor = + "Lai/onnxruntime/reactnative/OnnxruntimeModule;"; + + static void registerNatives() { + javaClassStatic()->registerNatives( + {makeNativeMethod("nativeInstall", + OnnxruntimeModule::nativeInstall), + makeNativeMethod("nativeCleanup", + OnnxruntimeModule::nativeCleanup)}); + } + + private: + static void nativeInstall(jni::alias_ref thiz, + jlong jsContextNativePointer, + jni::alias_ref + jsCallInvokerHolder) { + auto runtime = reinterpret_cast(jsContextNativePointer); + auto jsCallInvoker = jsCallInvokerHolder->cthis()->getCallInvoker(); + env = onnxruntimejsi::install(*runtime, jsCallInvoker); + } + + static void nativeCleanup(jni::alias_ref thiz) { env.reset(); } }; -extern "C" JNIEXPORT void JNICALL -Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv* env, jclass _, jlong jsiPtr, jobject instance) { - auto jsiRuntime = reinterpret_cast(jsiPtr); - - auto& runtime = *jsiRuntime; - - auto instanceGlobal = env->NewGlobalRef(instance); - - auto resolveArrayBuffer = jsi::Function::createFromHostFunction(runtime, - jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeResolveArrayBuffer"), - 1, - [=](jsi::Runtime& runtime, - const jsi::Value& thisValue, - const jsi::Value* arguments, - size_t count) -> jsi::Value { - if (count != 1) { - throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!"); - } - - jsi::Object data = arguments[0].asObject(runtime); - auto blobId = data.getProperty(runtime, "blobId").asString(runtime); - auto offset = data.getProperty(runtime, "offset").asNumber(); - auto size = data.getProperty(runtime, "size").asNumber(); - - auto bytes = getBytesFromBlob(env, instanceGlobal, blobId.utf8(runtime), offset, size); - - size_t totalSize = size - offset; - jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer"); - jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int)totalSize).getObject(runtime); - jsi::ArrayBuffer buf = o.getArrayBuffer(runtime); - memcpy(buf.data(runtime), reinterpret_cast(bytes), totalSize); - - return buf; - }); - runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", std::move(resolveArrayBuffer)); - - auto storeArrayBuffer = jsi::Function::createFromHostFunction(runtime, - jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeStoreArrayBuffer"), - 1, - [=](jsi::Runtime& runtime, - const jsi::Value& thisValue, - const jsi::Value* arguments, - size_t count) -> jsi::Value { - if (count != 1) { - throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) 
expects one argument (object)!"); - } - - auto arrayBuffer = arguments[0].asObject(runtime).getArrayBuffer(runtime); - auto size = arrayBuffer.size(runtime); - - std::string blobId = createBlob(env, instanceGlobal, arrayBuffer.data(runtime), size); - - jsi::Object result(runtime); - auto blobIdString = jsi::String::createFromUtf8(runtime, blobId); - result.setProperty(runtime, "blobId", blobIdString); - result.setProperty(runtime, "offset", jsi::Value(0)); - result.setProperty(runtime, "size", jsi::Value(static_cast(size))); - return result; - }); - runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", std::move(storeArrayBuffer)); +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + return jni::initialize( + vm, [] { OnnxruntimeModule::registerNatives(); }); } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java index de4c880981881..cacc382e29230 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java @@ -3,14 +3,13 @@ package ai.onnxruntime.reactnative; -import ai.onnxruntime.OrtSession.SessionOptions; import android.util.Log; class OnnxruntimeExtensions { - public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) { + static public String getLibraryPath() { Log.i("OnnxruntimeExtensions", "ORT Extensions is not enabled in the current configuration. If you want to enable this support, " + "please add \"onnxruntimeEnableExtensions\": \"true\" in your project root directory package.json."); - return; + return null; } } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java index 9bbf41c8f1671..d41163fdb53e9 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java @@ -3,12 +3,10 @@ package ai.onnxruntime.reactnative; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.extensions.OrtxPackage; class OnnxruntimeExtensions { - public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) throws OrtException { - sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath()); + static public String getLibraryPath() { + return OrtxPackage.getLibraryPath(); } } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java deleted file mode 100644 index 93b37df0768b4..0000000000000 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java +++ /dev/null @@ -1,70 +0,0 @@ -package ai.onnxruntime.reactnative; - -import androidx.annotation.NonNull; -import com.facebook.react.bridge.JavaScriptContextHolder; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReactContextBaseJavaModule; -import com.facebook.react.bridge.ReactMethod; -import com.facebook.react.module.annotations.ReactModule; -import com.facebook.react.modules.blob.BlobModule; - -@ReactModule(name 
= OnnxruntimeJSIHelper.NAME) -public class OnnxruntimeJSIHelper extends ReactContextBaseJavaModule { - public static final String NAME = "OnnxruntimeJSIHelper"; - - private static ReactApplicationContext reactContext; - protected BlobModule blobModule; - - public OnnxruntimeJSIHelper(ReactApplicationContext context) { - super(context); - reactContext = context; - } - - @Override - @NonNull - public String getName() { - return NAME; - } - - public void checkBlobModule() { - if (blobModule == null) { - blobModule = getReactApplicationContext().getNativeModule(BlobModule.class); - if (blobModule == null) { - throw new RuntimeException("BlobModule is not initialized"); - } - } - } - - @ReactMethod(isBlockingSynchronousMethod = true) - public boolean install() { - try { - System.loadLibrary("onnxruntimejsihelper"); - JavaScriptContextHolder jsContext = getReactApplicationContext().getJavaScriptContextHolder(); - nativeInstall(jsContext.get(), this); - return true; - } catch (Exception exception) { - return false; - } - } - - public byte[] getBlobBuffer(String blobId, int offset, int size) { - checkBlobModule(); - byte[] bytes = blobModule.resolve(blobId, offset, size); - blobModule.remove(blobId); - if (bytes == null) { - throw new RuntimeException("Failed to resolve Blob #" + blobId + "! Not found."); - } - return bytes; - } - - public String createBlob(byte[] buffer) { - checkBlobModule(); - String blobId = blobModule.store(buffer); - if (blobId == null) { - throw new RuntimeException("Failed to create Blob!"); - } - return blobId; - } - - public static native void nativeInstall(long jsiPointer, OnnxruntimeJSIHelper instance); -} diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index 496db5a6087e6..c362e6ad71bbe 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -3,65 +3,21 @@ package ai.onnxruntime.reactnative; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtLoggingLevel; -import ai.onnxruntime.OrtSession; -import ai.onnxruntime.OrtSession.Result; -import ai.onnxruntime.OrtSession.RunOptions; -import ai.onnxruntime.OrtSession.SessionOptions; -import ai.onnxruntime.providers.NNAPIFlags; -import android.net.Uri; +import java.util.Map; +import java.util.HashMap; import android.os.Build; -import android.util.Log; import androidx.annotation.NonNull; import androidx.annotation.RequiresApi; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.LifecycleEventListener; -import com.facebook.react.bridge.Promise; +import com.facebook.react.bridge.JavaScriptContextHolder; +import com.facebook.react.bridge.ReactMethod; import com.facebook.react.bridge.ReactApplicationContext; import com.facebook.react.bridge.ReactContextBaseJavaModule; -import com.facebook.react.bridge.ReactMethod; -import com.facebook.react.bridge.ReadableArray; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.bridge.ReadableType; -import com.facebook.react.bridge.WritableArray; -import com.facebook.react.bridge.WritableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import 
java.io.InputStreamReader; -import java.io.Reader; -import java.math.BigInteger; -import java.util.Collections; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import com.facebook.react.turbomodule.core.CallInvokerHolderImpl; @RequiresApi(api = Build.VERSION_CODES.N) -public class OnnxruntimeModule extends ReactContextBaseJavaModule implements LifecycleEventListener { +public class OnnxruntimeModule extends ReactContextBaseJavaModule { private static ReactApplicationContext reactContext; - private static OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment(); - private static Map sessionMap = new HashMap<>(); - - private static BigInteger nextSessionId = new BigInteger("0"); - private static String getNextSessionKey() { - String key = nextSessionId.toString(); - nextSessionId = nextSessionId.add(BigInteger.valueOf(1)); - return key; - } - - protected BlobModule blobModule; - public OnnxruntimeModule(ReactApplicationContext context) { super(context); reactContext = context; @@ -73,393 +29,37 @@ public String getName() { return "Onnxruntime"; } - public void checkBlobModule() { - if (blobModule == null) { - blobModule = getReactApplicationContext().getNativeModule(BlobModule.class); - if (blobModule == null) { - throw new RuntimeException("BlobModule is not initialized"); - } - } - } + native void nativeInstall(long jsiPointer, CallInvokerHolderImpl jsCallInvokerHolder); - /** - * React native binding API to load a model using given uri. - * - * @param uri a model file location - * @param options onnxruntime session options - * @param promise output returning back to react native js - * @note the value provided to `promise` includes a key representing the session. - * when run() is called, the key must be passed into the first parameter. - */ - @ReactMethod - public void loadModel(String uri, ReadableMap options, Promise promise) { - try { - WritableMap resultMap = loadModel(uri, options); - promise.resolve(resultMap); - } catch (Exception e) { - promise.reject("Failed to load model \"" + uri + "\": " + e.getMessage(), e); - } - } + native void nativeCleanup(); - /** - * React native binding API to load a model using blob object that data stored in BlobModule. - * - * @param data the blob object - * @param options onnxruntime session options - * @param promise output returning back to react native js - * @note the value provided to `promise` includes a key representing the session. - * when run() is called, the key must be passed into the first parameter. - */ - @ReactMethod - public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise promise) { - try { - checkBlobModule(); - String blobId = data.getString("blobId"); - byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size")); - blobModule.remove(blobId); - WritableMap resultMap = loadModel(bytes, options); - promise.resolve(resultMap); - } catch (Exception e) { - promise.reject("Failed to load model from buffer: " + e.getMessage(), e); - } - } - - /** - * React native binding API to dispose a session. 
- * - * @param key session key representing a session given at loadModel() - * @param promise output returning back to react native js - */ - @ReactMethod - public void dispose(String key, Promise promise) { - try { - dispose(key); - promise.resolve(null); - } catch (OrtException e) { - promise.reject("Failed to dispose session: " + e.getMessage(), e); - } + @Override + public void invalidate() { + super.invalidate(); + nativeCleanup(); } /** - * React native binding API to run a model using given uri. - * - * @param key session key representing a session given at loadModel() - * @param input an input tensor - * @param output an output names to be returned - * @param options onnxruntime run options - * @param promise output returning back to react native js + * Install onnxruntime JSI API */ - @ReactMethod - public void run(String key, ReadableMap input, ReadableArray output, ReadableMap options, Promise promise) { + @ReactMethod(isBlockingSynchronousMethod = true) + public boolean install() { try { - WritableMap resultMap = run(key, input, output, options); - promise.resolve(resultMap); + System.loadLibrary("onnxruntimejsi"); + JavaScriptContextHolder jsContext = getReactApplicationContext().getJavaScriptContextHolder(); + CallInvokerHolderImpl jsCallInvokerHolder = + (CallInvokerHolderImpl) getReactApplicationContext().getCatalystInstance().getJSCallInvokerHolder(); + nativeInstall(jsContext.get(), jsCallInvokerHolder); + return true; } catch (Exception e) { - promise.reject("Fail to inference: " + e.getMessage(), e); - } - } - - /** - * Load a model from raw resource directory. - * - * @param uri uri parameter from react native loadModel() - * @param options onnxruntime session options - * @return model loading information, such as key, input names, and output names - */ - public WritableMap loadModel(String uri, ReadableMap options) throws Exception { - return loadModelImpl(uri, null, options); - } - - /** - * Load a model from buffer. - * - * @param modelData the model data buffer - * @param options onnxruntime session options - * @return model loading information, such as key, input names, and output names - */ - public WritableMap loadModel(byte[] modelData, ReadableMap options) throws Exception { - return loadModelImpl("", modelData, options); - } - - /** - * Load model implementation method for either from model path or model data buffer. 
- * - * @param uri uri parameter from react native loadModel() - * @param modelData model data buffer - * @param options onnxruntime session options - * @return model loading information map, such as key, input names, and output names - */ - private WritableMap loadModelImpl(String uri, byte[] modelData, ReadableMap options) throws Exception { - OrtSession ortSession; - SessionOptions sessionOptions = parseSessionOptions(options); - - // optional call for registering custom ops when ort extensions enabled - OnnxruntimeExtensions ortExt = new OnnxruntimeExtensions(); - ortExt.registerOrtExtensionsIfEnabled(sessionOptions); - - if (modelData != null && modelData.length > 0) { - // load model via model data array - ortSession = ortEnvironment.createSession(modelData, sessionOptions); - } else if (uri.startsWith("file://") || uri.startsWith("/")) { - // load model from local - if (uri.startsWith("file://")) { - uri = uri.substring(7); - } - ortSession = ortEnvironment.createSession(uri, sessionOptions); - } else { - // load model via model path string uri - InputStream modelStream = - reactContext.getApplicationContext().getContentResolver().openInputStream(Uri.parse(uri)); - Reader reader = new BufferedReader(new InputStreamReader(modelStream)); - byte[] modelArray = new byte[modelStream.available()]; - modelStream.read(modelArray); - modelStream.close(); - ortSession = ortEnvironment.createSession(modelArray, sessionOptions); - } - - String key = getNextSessionKey(); - sessionMap.put(key, ortSession); - - WritableMap resultMap = Arguments.createMap(); - resultMap.putString("key", key); - WritableArray inputNames = Arguments.createArray(); - for (String inputName : ortSession.getInputNames()) { - inputNames.pushString(inputName); - } - resultMap.putArray("inputNames", inputNames); - WritableArray outputNames = Arguments.createArray(); - for (String outputName : ortSession.getOutputNames()) { - outputNames.pushString(outputName); - } - resultMap.putArray("outputNames", outputNames); - - return resultMap; - } - - /** - * Dispose a model using given key. - * - * @param key a session key representing the session given at loadModel() - */ - public void dispose(String key) throws OrtException { - OrtSession ortSession = sessionMap.get(key); - if (ortSession != null) { - ortSession.close(); - sessionMap.remove(key); - } - } - - /** - * Run a model using given uri. 
- * - * @param key a session key representing the session given at loadModel() - * @param input an input tensor - * @param output an output names to be returned - * @param options onnxruntime run options - * @return inference result - */ - public WritableMap run(String key, ReadableMap input, ReadableArray output, ReadableMap options) throws Exception { - OrtSession ortSession = sessionMap.get(key); - if (ortSession == null) { - throw new Exception("Model is not loaded."); - } - - RunOptions runOptions = parseRunOptions(options); - - checkBlobModule(); - - long startTime = System.currentTimeMillis(); - Map feed = new HashMap<>(); - Iterator iterator = ortSession.getInputNames().iterator(); - Result result = null; - try { - while (iterator.hasNext()) { - String inputName = iterator.next(); - - ReadableMap inputMap = input.getMap(inputName); - if (inputMap == null) { - throw new Exception("Can't find input: " + inputName); - } - - OnnxTensor onnxTensor = TensorHelper.createInputTensor(blobModule, inputMap, ortEnvironment); - feed.put(inputName, onnxTensor); - } - - Set requestedOutputs = null; - if (output.size() > 0) { - requestedOutputs = new HashSet<>(); - for (int i = 0; i < output.size(); ++i) { - requestedOutputs.add(output.getString(i)); - } - } - - long duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "createInputTensor: " + duration); - - startTime = System.currentTimeMillis(); - if (requestedOutputs != null) { - result = ortSession.run(feed, requestedOutputs, runOptions); - } else { - result = ortSession.run(feed, runOptions); - } - duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "inference: " + duration); - - startTime = System.currentTimeMillis(); - WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "createOutputTensor: " + duration); - - return resultMap; - - } finally { - OnnxValue.close(feed); - if (result != null) { - result.close(); - } - } - } - - private static final Map graphOptimizationLevelTable = - Stream - .of(new Object[][] { - {"disabled", SessionOptions.OptLevel.NO_OPT}, - {"basic", SessionOptions.OptLevel.BASIC_OPT}, - {"extended", SessionOptions.OptLevel.EXTENDED_OPT}, - // {"layout", SessionOptions.OptLevel.LAYOUT_OPT}, - {"all", SessionOptions.OptLevel.ALL_OPT}, - }) - .collect(Collectors.toMap(p -> (String)p[0], p -> (SessionOptions.OptLevel)p[1])); - - private static final Map executionModeTable = - Stream - .of(new Object[][] {{"sequential", SessionOptions.ExecutionMode.SEQUENTIAL}, - {"parallel", SessionOptions.ExecutionMode.PARALLEL}}) - .collect(Collectors.toMap(p -> (String)p[0], p -> (SessionOptions.ExecutionMode)p[1])); - - private SessionOptions parseSessionOptions(ReadableMap options) throws OrtException { - SessionOptions sessionOptions = new SessionOptions(); - - if (options.hasKey("intraOpNumThreads")) { - int intraOpNumThreads = options.getInt("intraOpNumThreads"); - if (intraOpNumThreads > 0 && intraOpNumThreads < Integer.MAX_VALUE) { - sessionOptions.setIntraOpNumThreads(intraOpNumThreads); - } - } - - if (options.hasKey("interOpNumThreads")) { - int interOpNumThreads = options.getInt("interOpNumThreads"); - if (interOpNumThreads > 0 && interOpNumThreads < Integer.MAX_VALUE) { - sessionOptions.setInterOpNumThreads(interOpNumThreads); - } + return false; } - - if (options.hasKey("graphOptimizationLevel")) { - String graphOptimizationLevel = options.getString("graphOptimizationLevel"); - if 
(graphOptimizationLevelTable.containsKey(graphOptimizationLevel)) { - sessionOptions.setOptimizationLevel(graphOptimizationLevelTable.get(graphOptimizationLevel)); - } - } - - if (options.hasKey("enableCpuMemArena")) { - boolean enableCpuMemArena = options.getBoolean("enableCpuMemArena"); - sessionOptions.setCPUArenaAllocator(enableCpuMemArena); - } - - if (options.hasKey("enableMemPattern")) { - boolean enableMemPattern = options.getBoolean("enableMemPattern"); - sessionOptions.setMemoryPatternOptimization(enableMemPattern); - } - - if (options.hasKey("executionMode")) { - String executionMode = options.getString("executionMode"); - if (executionModeTable.containsKey(executionMode)) { - sessionOptions.setExecutionMode(executionModeTable.get(executionMode)); - } - } - - if (options.hasKey("executionProviders")) { - ReadableArray executionProviders = options.getArray("executionProviders"); - for (int i = 0; i < executionProviders.size(); ++i) { - String epName = null; - ReadableMap epOptions = null; - if (executionProviders.getType(i) == ReadableType.String) { - epName = executionProviders.getString(i); - } else { - epOptions = executionProviders.getMap(i); - epName = epOptions.getString("name"); - } - if (epName.equals("nnapi")) { - EnumSet flags = EnumSet.noneOf(NNAPIFlags.class); - if (epOptions != null) { - if (epOptions.hasKey("useFP16") && epOptions.getBoolean("useFP16")) { - flags.add(NNAPIFlags.USE_FP16); - } - if (epOptions.hasKey("useNCHW") && epOptions.getBoolean("useNCHW")) { - flags.add(NNAPIFlags.USE_NCHW); - } - if (epOptions.hasKey("cpuDisabled") && epOptions.getBoolean("cpuDisabled")) { - flags.add(NNAPIFlags.CPU_DISABLED); - } - if (epOptions.hasKey("cpuOnly") && epOptions.getBoolean("cpuOnly")) { - flags.add(NNAPIFlags.CPU_ONLY); - } - } - sessionOptions.addNnapi(flags); - } else if (epName.equals("xnnpack")) { - sessionOptions.addXnnpack(Collections.emptyMap()); - } else if (epName.equals("cpu")) { - continue; - } else { - throw new OrtException("Unsupported execution provider: " + epName); - } - } - } - - if (options.hasKey("logId")) { - String logId = options.getString("logId"); - sessionOptions.setLoggerId(logId); - } - - if (options.hasKey("logSeverityLevel")) { - int logSeverityLevel = options.getInt("logSeverityLevel"); - sessionOptions.setSessionLogLevel(OrtLoggingLevel.mapFromInt(logSeverityLevel)); - } - - return sessionOptions; } - private RunOptions parseRunOptions(ReadableMap options) throws OrtException { - RunOptions runOptions = new RunOptions(); - - if (options.hasKey("logSeverityLevel")) { - int logSeverityLevel = options.getInt("logSeverityLevel"); - runOptions.setLogLevel(OrtLoggingLevel.mapFromInt(logSeverityLevel)); - } - - if (options.hasKey("tag")) { - String tag = options.getString("tag"); - runOptions.setRunTag(tag); - } - - return runOptions; - } - - @Override - public void onHostResume() {} - @Override - public void onHostPause() {} - - @Override - public void onHostDestroy() { - for (String key : sessionMap.keySet()) { - try { - dispose(key); - } catch (Exception e) { - Log.e("onHostDestroy", "Failed to dispose session: " + key, e); - } - } - sessionMap.clear(); + public Map getConstants() { + final Map constants = new HashMap(); + constants.put("ORT_EXTENSIONS_PATH", OnnxruntimeExtensions.getLibraryPath()); + return constants; } } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java index 
bb4386a0953f3..9171641e6e68a 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java @@ -22,7 +22,6 @@ public class OnnxruntimePackage implements ReactPackage { public List createNativeModules(@NonNull ReactApplicationContext reactContext) { List modules = new ArrayList<>(); modules.add(new OnnxruntimeModule(reactContext)); - modules.add(new OnnxruntimeJSIHelper(reactContext)); return modules; } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java deleted file mode 100644 index 63cddace36640..0000000000000 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package ai.onnxruntime.reactnative; - -import ai.onnxruntime.OnnxJavaType; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtSession; -import ai.onnxruntime.OrtUtil; -import ai.onnxruntime.TensorInfo; -import android.util.Base64; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.ReadableArray; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.bridge.WritableArray; -import com.facebook.react.bridge.WritableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import java.nio.ShortBuffer; -import java.util.Iterator; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -public class TensorHelper { - /** - * Supported tensor data type - */ - public static final String JsTensorTypeBool = "bool"; - public static final String JsTensorTypeByte = "int8"; - public static final String JsTensorTypeUnsignedByte = "uint8"; - public static final String JsTensorTypeShort = "int16"; - public static final String JsTensorTypeInt = "int32"; - public static final String JsTensorTypeLong = "int64"; - public static final String JsTensorTypeFloat = "float32"; - public static final String JsTensorTypeDouble = "float64"; - public static final String JsTensorTypeString = "string"; - - /** - * It creates an input tensor from a map passed by react native js. - * 'data' is blob object and the buffer is stored in BlobModule. It first resolve it and creates a tensor. 
- */ - public static OnnxTensor createInputTensor(BlobModule blobModule, ReadableMap inputTensor, - OrtEnvironment ortEnvironment) throws Exception { - // shape - ReadableArray dimsArray = inputTensor.getArray("dims"); - long[] dims = new long[dimsArray.size()]; - for (int i = 0; i < dimsArray.size(); ++i) { - dims[i] = dimsArray.getInt(i); - } - - // type - TensorInfo.OnnxTensorType tensorType = getOnnxTensorType(inputTensor.getString("type")); - - // data - OnnxTensor onnxTensor = null; - if (tensorType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - ReadableArray values = inputTensor.getArray("data"); - String[] buffer = new String[values.size()]; - for (int i = 0; i < values.size(); ++i) { - buffer[i] = values.getString(i); - } - onnxTensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - } else { - ReadableMap data = inputTensor.getMap("data"); - String blobId = data.getString("blobId"); - byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size")); - blobModule.remove(blobId); - ByteBuffer values = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()); - onnxTensor = createInputTensor(tensorType, dims, values, ortEnvironment); - } - - return onnxTensor; - } - - /** - * It creates an output map from an output tensor. - * a data array is store in BlobModule. - */ - public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception { - WritableMap outputTensorMap = Arguments.createMap(); - - Iterator> iterator = result.iterator(); - while (iterator.hasNext()) { - Map.Entry entry = iterator.next(); - String outputName = entry.getKey(); - OnnxValue onnxValue = (OnnxValue)entry.getValue(); - if (onnxValue.getType() != OnnxValue.OnnxValueType.ONNX_TYPE_TENSOR) { - throw new Exception("Not supported type: " + onnxValue.getType().toString()); - } - - OnnxTensor onnxTensor = (OnnxTensor)onnxValue; - WritableMap outputTensor = Arguments.createMap(); - - // dims - WritableArray outputDims = Arguments.createArray(); - long[] dims = onnxTensor.getInfo().getShape(); - for (long dim : dims) { - outputDims.pushInt((int)dim); - } - outputTensor.putArray("dims", outputDims); - - // type - outputTensor.putString("type", getJsTensorType(onnxTensor.getInfo().onnxType)); - - // data - if (onnxTensor.getInfo().onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - String[] buffer = (String[])onnxTensor.getValue(); - WritableArray dataArray = Arguments.createArray(); - for (String value : buffer) { - dataArray.pushString(value); - } - outputTensor.putArray("data", dataArray); - } else { - // Store in BlobModule then create a blob object as data - byte[] bufferArray = createOutputTensor(onnxTensor); - WritableMap data = Arguments.createMap(); - data.putString("blobId", blobModule.store(bufferArray)); - data.putInt("offset", 0); - data.putInt("size", bufferArray.length); - outputTensor.putMap("data", data); - } - - outputTensorMap.putMap(outputName, outputTensor); - } - - return outputTensorMap; - } - - private static OnnxTensor createInputTensor(TensorInfo.OnnxTensorType tensorType, long[] dims, ByteBuffer values, - OrtEnvironment ortEnvironment) throws Exception { - OnnxTensor tensor = null; - switch (tensorType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - FloatBuffer buffer = values.asFloatBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - ByteBuffer buffer = values; - tensor = 
OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - ShortBuffer buffer = values.asShortBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - IntBuffer buffer = values.asIntBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - LongBuffer buffer = values.asLongBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - DoubleBuffer buffer = values.asDoubleBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - ByteBuffer buffer = values; - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - ByteBuffer buffer = values; - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.BOOL); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - default: - throw new IllegalStateException("Unexpected value: " + tensorType.toString()); - } - - return tensor; - } - - private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception { - TensorInfo tensorInfo = onnxTensor.getInfo(); - ByteBuffer buffer = null; - - int capacity = (int)OrtUtil.elementCount(onnxTensor.getInfo().getShape()); - - switch (tensorInfo.onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - buffer = ByteBuffer.allocate(capacity * 4).order(ByteOrder.nativeOrder()); - buffer.asFloatBuffer().put(onnxTensor.getFloatBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - buffer = ByteBuffer.allocate(capacity).order(ByteOrder.nativeOrder()); - buffer.put(onnxTensor.getByteBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - buffer = ByteBuffer.allocate(capacity * 2).order(ByteOrder.nativeOrder()); - buffer.asShortBuffer().put(onnxTensor.getShortBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - buffer = ByteBuffer.allocate(capacity * 4).order(ByteOrder.nativeOrder()); - buffer.asIntBuffer().put(onnxTensor.getIntBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - buffer = ByteBuffer.allocate(capacity * 8).order(ByteOrder.nativeOrder()); - buffer.asLongBuffer().put(onnxTensor.getLongBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - buffer = ByteBuffer.allocate(capacity * 8).order(ByteOrder.nativeOrder()); - buffer.asDoubleBuffer().put(onnxTensor.getDoubleBuffer()); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - buffer = ByteBuffer.allocate(capacity).order(ByteOrder.nativeOrder()); - buffer.put(onnxTensor.getByteBuffer()); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - default: - throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString()); - } - - return buffer.array(); - } - - private static final Map JsTensorTypeToOnnxTensorTypeMap = - Stream - .of(new Object[][] { - {JsTensorTypeFloat, 
TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, - {JsTensorTypeByte, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, - {JsTensorTypeUnsignedByte, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, - {JsTensorTypeShort, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, - {JsTensorTypeInt, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, - {JsTensorTypeLong, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, - {JsTensorTypeString, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, - {JsTensorTypeBool, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, - {JsTensorTypeDouble, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, - }) - .collect(Collectors.toMap(p -> (String)p[0], p -> (TensorInfo.OnnxTensorType)p[1])); - - private static TensorInfo.OnnxTensorType getOnnxTensorType(String type) { - if (JsTensorTypeToOnnxTensorTypeMap.containsKey(type)) { - return JsTensorTypeToOnnxTensorTypeMap.get(type); - } else { - return TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - } - } - - private static final Map OnnxTensorTypeToJsTensorTypeMap = - Stream - .of(new Object[][] { - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, JsTensorTypeFloat}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, JsTensorTypeByte}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, JsTensorTypeUnsignedByte}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, JsTensorTypeShort}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, JsTensorTypeInt}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, JsTensorTypeLong}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, JsTensorTypeString}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, JsTensorTypeBool}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, JsTensorTypeDouble}, - }) - .collect(Collectors.toMap(p -> (TensorInfo.OnnxTensorType)p[0], p -> (String)p[1])); - - private static String getJsTensorType(TensorInfo.OnnxTensorType type) { - if (OnnxTensorTypeToJsTensorTypeMap.containsKey(type)) { - return OnnxTensorTypeToJsTensorTypeMap.get(type); - } else { - return "undefined"; - } - } -} diff --git a/js/react_native/cpp/AsyncWorker.h b/js/react_native/cpp/AsyncWorker.h new file mode 100644 index 0000000000000..ceca5c0ac203e --- /dev/null +++ b/js/react_native/cpp/AsyncWorker.h @@ -0,0 +1,131 @@ +#pragma once + +#include "Env.h" +#include +#include +#include +#include +#include +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +/** + * @brief AsyncWorker is a helper class to run a function asynchronously and + * return a promise. + * + * @param rt The runtime to use. + * @param env The environment to use. + */ +class AsyncWorker : public HostObject, public std::enable_shared_from_this { + public: + AsyncWorker(Runtime& rt, std::shared_ptr env) : rt_(rt), env_(env), cancel_(false) {} + + ~AsyncWorker() { + if (worker_.joinable()) { + if (worker_.get_id() != std::this_thread::get_id()) { + cancel_ = true; + onAbort(); + worker_.join(); + } else { + worker_.detach(); + } + } + } + + /** + * @brief Make sure the value won't be garbage collected during the async + * operation. + * + * @param rt The runtime to use. + * @param value The value to keep. 
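+   *
+   * The kept reference guards memory that the worker thread reads directly
+   * (for example, the ArrayBuffer passed to loadModel); every kept value is
+   * released in clearKeeps() once the promise resolves or rejects.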
+ */ + void keepValue(Runtime& rt, const Value& value) { + keptValues_.push_back(std::make_shared(rt, value)); + } + + /** + * @brief Create a promise to be used in the async operation. + * + * @param rt The runtime to use. + * @return The promise. + */ + Value toPromise(Runtime& rt) { + auto promiseCtor = rt.global().getPropertyAsFunction(rt, "Promise"); + + auto promise = promiseCtor.callAsConstructor( + rt, Function::createFromHostFunction( + rt, PropNameID::forAscii(rt, "executor"), 2, + [this](Runtime& rt, const Value& thisVal, const Value* args, + size_t count) -> Value { + resolveFunc_ = std::make_shared(rt, args[0]); + rejectFunc_ = std::make_shared(rt, args[1]); + cancel_ = false; + worker_ = std::thread([this]() { + if (cancel_) return; + try { + execute(); + dispatchResolve(); + } catch (const std::exception& e) { + dispatchReject(e.what()); + } + }); + return Value::undefined(); + })); + promise.asObject(rt).setProperty(rt, "__nativeWorker", Object::createFromHostObject(rt, shared_from_this())); + return promise; + } + + protected: + virtual void execute() = 0; + + virtual Value onResolve(Runtime& rt) = 0; + virtual Value onReject(Runtime& rt, const std::string& err) { + return String::createFromUtf8(rt, err); + } + + virtual void onAbort() {} + + private: + void dispatchResolve() { + if (cancel_) return; + auto self = shared_from_this(); + env_->runOnJsThread([self]() { + auto resVal = self->onResolve(self->rt_); + self->resolveFunc_->asObject(self->rt_) + .asFunction(self->rt_) + .call(self->rt_, resVal); + self->clearKeeps(); + }); + } + + void dispatchReject(const std::string& err) { + if (cancel_) return; + auto self = shared_from_this(); + env_->runOnJsThread([self, err]() { + auto resVal = self->onReject(self->rt_, err); + self->rejectFunc_->asObject(self->rt_) + .asFunction(self->rt_) + .call(self->rt_, resVal); + self->clearKeeps(); + }); + } + + void clearKeeps() { + keptValues_.clear(); + resolveFunc_.reset(); + rejectFunc_.reset(); + } + + Runtime& rt_; + std::shared_ptr env_; + std::atomic cancel_; + std::vector> keptValues_; + std::shared_ptr resolveFunc_; + std::shared_ptr rejectFunc_; + std::thread worker_; +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/Env.h b/js/react_native/cpp/Env.h new file mode 100644 index 0000000000000..9e7b7651a971f --- /dev/null +++ b/js/react_native/cpp/Env.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "onnxruntime_cxx_api.h" +#include + +namespace onnxruntimejsi { + +class Env : public std::enable_shared_from_this { + public: + Env(std::shared_ptr jsInvoker) + : jsInvoker_(jsInvoker) {} + + ~Env() {} + + inline void initOrtEnv(OrtLoggingLevel logLevel, const char* logid) { + if (ortEnv_) { + return; + } + ortEnv_ = std::make_shared(logLevel, logid); + } + + inline void setTensorConstructor( + std::shared_ptr tensorConstructor) { + tensorConstructor_ = tensorConstructor; + } + + inline facebook::jsi::Value + getTensorConstructor(facebook::jsi::Runtime& runtime) const { + return tensorConstructor_->lock(runtime); + } + + inline Ort::Env& getOrtEnv() const { return *ortEnv_; } + + inline void runOnJsThread(std::function&& func) { + if (!jsInvoker_) return; + jsInvoker_->invokeAsync(std::move(func)); + } + + private: + std::shared_ptr jsInvoker_; + std::shared_ptr tensorConstructor_; + std::shared_ptr ortEnv_; +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/InferenceSessionHostObject.cpp b/js/react_native/cpp/InferenceSessionHostObject.cpp new 
file mode 100644 index 0000000000000..c8efd49e6d669 --- /dev/null +++ b/js/react_native/cpp/InferenceSessionHostObject.cpp @@ -0,0 +1,312 @@ +#include "InferenceSessionHostObject.h" +#include "AsyncWorker.h" +#include "JsiUtils.h" +#include "SessionUtils.h" +#include "TensorUtils.h" +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +class InferenceSessionHostObject::LoadModelAsyncWorker : public AsyncWorker { + public: + LoadModelAsyncWorker(Runtime& runtime, const Value* arguments, size_t count, + std::shared_ptr session) + : AsyncWorker(runtime, session->env_), session_(session) { + if (count < 1) + throw JSError(runtime, "loadModel requires at least 1 argument"); + if (arguments[0].isString()) { + modelPath_ = arguments[0].asString(runtime).utf8(runtime); + if (modelPath_.find("file:/") == 0) { + modelPath_ = modelPath_.substr(5); + if (modelPath_.find("//") == 0) { + modelPath_ = modelPath_.substr(2); + } + } + } else if (arguments[0].isObject() && + arguments[0].asObject(runtime).isArrayBuffer(runtime)) { + auto arrayBufferObj = arguments[0].asObject(runtime); + auto arrayBuffer = arrayBufferObj.getArrayBuffer(runtime); + modelData_ = arrayBuffer.data(runtime); + modelDataLength_ = arrayBuffer.size(runtime); + } else { + throw JSError(runtime, "Model path or buffer is required"); + } + keepValue(runtime, arguments[0]); + if (count > 1) { + parseSessionOptions(runtime, arguments[1], sessionOptions_); + } + } + + protected: + void execute() { + if (modelPath_.empty()) { + session_->session_ = std::make_shared( + session_->env_->getOrtEnv(), modelData_, modelDataLength_, + sessionOptions_); + } else { + session_->session_ = std::make_shared( + session_->env_->getOrtEnv(), modelPath_.c_str(), sessionOptions_); + } + } + + Value onResolve(Runtime& rt) { return Value::undefined(); } + + private: + std::string error_; + std::string modelPath_; + void* modelData_; + size_t modelDataLength_; + std::shared_ptr session_; + Ort::SessionOptions sessionOptions_; + std::shared_ptr weakResolve_; + std::shared_ptr weakReject_; + std::thread thread_; +}; + +DEFINE_METHOD(InferenceSessionHostObject::loadModel) { + auto self = shared_from_this(); + auto worker = + std::make_shared(runtime, arguments, count, self); + return worker->toPromise(runtime); +} + +class InferenceSessionHostObject::RunAsyncWorker : public AsyncWorker { + public: + RunAsyncWorker(Runtime& runtime, const Value* arguments, size_t count, + std::shared_ptr session) + : AsyncWorker(runtime, session->env_), + env_(session->env_), + session_(session->session_), + memoryInfo_(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault)) { + if (count < 1) + throw JSError(runtime, "run requires at least 1 argument"); + if (count > 2 && !arguments[2].isUndefined()) { + parseRunOptions(runtime, arguments[2], runOptions_); + } + forEach(runtime, arguments[0].asObject(runtime), + [&](const std::string& key, const Value& value, size_t index) { + inputNames_.push_back(key); + inputValues_.push_back(TensorUtils::createOrtValueFromJSTensor( + runtime, value.asObject(runtime), memoryInfo_)); + keepValue(runtime, value); + }); + forEach(runtime, arguments[1].asObject(runtime), + [&](const std::string& key, const Value& value, size_t index) { + outputNames_.push_back(key); + if (value.isObject() && + TensorUtils::isTensor(runtime, value.asObject(runtime))) { + outputValues_.push_back(TensorUtils::createOrtValueFromJSTensor( + runtime, value.asObject(runtime), memoryInfo_)); + jsOutputValues_.push_back(std::make_shared( 
+ runtime, value.asObject(runtime))); + keepValue(runtime, value); + } else { + outputValues_.push_back(Ort::Value()); + jsOutputValues_.push_back(nullptr); + } + }); + } + + protected: + void execute() { + auto inputNames = std::vector(inputNames_.size()); + std::transform(inputNames_.begin(), inputNames_.end(), inputNames.begin(), + [](const std::string& name) { return name.c_str(); }); + auto outputNames = std::vector(outputNames_.size()); + std::transform(outputNames_.begin(), outputNames_.end(), + outputNames.begin(), + [](const std::string& name) { return name.c_str(); }); + auto session = session_.lock(); + if (!session) { + throw std::runtime_error("Session is released"); + } + session->Run(runOptions_, inputNames.data(), inputValues_.data(), + inputValues_.size(), outputNames.data(), + outputValues_.data(), outputValues_.size()); + } + + Value onResolve(Runtime& rt) { + auto resultObject = Object(rt); + auto tensorConstructor = + env_->getTensorConstructor(rt).asObject(rt); + for (size_t i = 0; i < outputValues_.size(); ++i) { + if (jsOutputValues_[i] != nullptr && outputValues_[i].IsTensor()) { + resultObject.setProperty(rt, outputNames_[i].c_str(), + jsOutputValues_[i]->lock(rt)); + } else { + auto tensorObj = TensorUtils::createJSTensorFromOrtValue( + rt, outputValues_[i], tensorConstructor); + resultObject.setProperty(rt, outputNames_[i].c_str(), + Value(rt, tensorObj)); + } + } + return Value(rt, resultObject); + } + + void onAbort() { + runOptions_.SetTerminate(); + } + + private: + std::shared_ptr env_; + std::weak_ptr session_; + Ort::MemoryInfo memoryInfo_; + Ort::RunOptions runOptions_; + std::vector inputNames_; + std::vector inputValues_; + std::vector outputNames_; + std::vector outputValues_; + std::vector> jsOutputValues_; +}; + +DEFINE_METHOD(InferenceSessionHostObject::run) { + auto self = shared_from_this(); + auto worker = + std::make_shared(runtime, arguments, count, self); + return worker->toPromise(runtime); +} + +DEFINE_METHOD(InferenceSessionHostObject::dispose) { + session_.reset(); + return Value::undefined(); +} + +DEFINE_METHOD(InferenceSessionHostObject::endProfiling) { + try { + Ort::AllocatorWithDefaultOptions allocator; + auto filename = session_->EndProfilingAllocated(allocator); + return String::createFromUtf8(runtime, std::string(filename.get())); + } catch (const std::exception& e) { + throw JSError(runtime, std::string(e.what())); + } +} + +DEFINE_GETTER(InferenceSessionHostObject::inputMetadata) { + if (!session_) { + return Array(runtime, 0); + } + try { + Ort::AllocatorWithDefaultOptions allocator; + size_t numInputs = session_->GetInputCount(); + auto array = Array(runtime, numInputs); + + for (size_t i = 0; i < numInputs; i++) { + auto item = Object(runtime); + auto inputName = session_->GetInputNameAllocated(i, allocator); + item.setProperty( + runtime, "name", + String::createFromUtf8(runtime, std::string(inputName.get()))); + + try { + auto typeInfo = session_->GetInputTypeInfo(i); + auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo(); + + // Get data type + auto dataType = tensorInfo.GetElementType(); + item.setProperty(runtime, "type", static_cast(dataType)); + + // Get shape + auto shape = tensorInfo.GetShape(); + auto shapeArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); j++) { + shapeArray.setValueAtIndex(runtime, j, + Value(static_cast(shape[j]))); + } + item.setProperty(runtime, "shape", shapeArray); + + item.setProperty(runtime, "isTensor", Value(true)); + + // symbolicDimensions + auto 
symbolicDimensions = tensorInfo.GetSymbolicDimensions(); + auto symbolicDimensionsArray = + Array(runtime, symbolicDimensions.size()); + for (size_t j = 0; j < symbolicDimensions.size(); j++) { + symbolicDimensionsArray.setValueAtIndex( + runtime, j, + String::createFromUtf8(runtime, symbolicDimensions[j])); + } + item.setProperty(runtime, "symbolicDimensions", + symbolicDimensionsArray); + } catch (const std::exception&) { + // Fallback for unknown types + item.setProperty(runtime, "type", + String::createFromUtf8(runtime, "unknown")); + item.setProperty(runtime, "shape", Array(runtime, 0)); + item.setProperty(runtime, "isTensor", Value(false)); + } + + array.setValueAtIndex(runtime, i, Value(runtime, item)); + } + + return Value(runtime, array); + } catch (const Ort::Exception& e) { + throw JSError(runtime, std::string(e.what())); + } +} + +DEFINE_GETTER(InferenceSessionHostObject::outputMetadata) { + if (!session_) { + return Array(runtime, 0); + } + try { + Ort::AllocatorWithDefaultOptions allocator; + size_t numOutputs = session_->GetOutputCount(); + auto array = Array(runtime, numOutputs); + + for (size_t i = 0; i < numOutputs; i++) { + auto item = Object(runtime); + auto outputName = session_->GetOutputNameAllocated(i, allocator); + item.setProperty( + runtime, "name", + String::createFromUtf8(runtime, std::string(outputName.get()))); + + try { + auto typeInfo = session_->GetOutputTypeInfo(i); + auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo(); + + // Get data type + auto dataType = tensorInfo.GetElementType(); + item.setProperty(runtime, "type", static_cast(dataType)); + + // Get shape + auto shape = tensorInfo.GetShape(); + auto shapeArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); j++) { + shapeArray.setValueAtIndex(runtime, j, + Value(static_cast(shape[j]))); + } + item.setProperty(runtime, "shape", shapeArray); + + item.setProperty(runtime, "isTensor", Value(true)); + + // symbolicDimensions + auto symbolicDimensions = tensorInfo.GetSymbolicDimensions(); + auto symbolicDimensionsArray = + Array(runtime, symbolicDimensions.size()); + for (size_t j = 0; j < symbolicDimensions.size(); j++) { + symbolicDimensionsArray.setValueAtIndex( + runtime, j, + String::createFromUtf8(runtime, symbolicDimensions[j])); + } + item.setProperty(runtime, "symbolicDimensions", + symbolicDimensionsArray); + } catch (const std::exception&) { + // Fallback for unknown types + item.setProperty(runtime, "type", + String::createFromUtf8(runtime, "unknown")); + item.setProperty(runtime, "shape", Array(runtime, 0)); + item.setProperty(runtime, "isTensor", Value(false)); + } + + array.setValueAtIndex(runtime, i, Value(runtime, item)); + } + + return Value(runtime, array); + } catch (const Ort::Exception& e) { + throw JSError(runtime, std::string(e.what())); + } +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/InferenceSessionHostObject.h b/js/react_native/cpp/InferenceSessionHostObject.h new file mode 100644 index 0000000000000..f13a8d46d4048 --- /dev/null +++ b/js/react_native/cpp/InferenceSessionHostObject.h @@ -0,0 +1,55 @@ +#pragma once + +#include "Env.h" +#include "JsiHelper.h" +#include +#include +#include "onnxruntime_cxx_api.h" +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +class InferenceSessionHostObject + : public HostObjectHelper, + public std::enable_shared_from_this { + public: + InferenceSessionHostObject(std::shared_ptr env) : HostObjectHelper({ + METHOD_INFO(InferenceSessionHostObject, loadModel, 2), + 
METHOD_INFO(InferenceSessionHostObject, run, 2), + METHOD_INFO(InferenceSessionHostObject, dispose, 0), + METHOD_INFO(InferenceSessionHostObject, endProfiling, 0), + }, + { + GETTER_INFO(InferenceSessionHostObject, inputMetadata), + GETTER_INFO(InferenceSessionHostObject, outputMetadata), + }), + env_(env) {} + + static inline facebook::jsi::Value + constructor(std::shared_ptr env, facebook::jsi::Runtime& runtime, + const facebook::jsi::Value& thisValue, + const facebook::jsi::Value* arguments, size_t count) { + return facebook::jsi::Object::createFromHostObject( + runtime, std::make_shared(env)); + } + + protected: + class LoadModelAsyncWorker; + class RunAsyncWorker; + + private: + std::shared_ptr env_; + std::shared_ptr session_; + + DEFINE_METHOD(loadModel); + DEFINE_METHOD(run); + DEFINE_METHOD(dispose); + DEFINE_METHOD(endProfiling); + + DEFINE_GETTER(inputMetadata); + DEFINE_GETTER(outputMetadata); +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/JsiHelper.h b/js/react_native/cpp/JsiHelper.h new file mode 100644 index 0000000000000..953429e2c26d5 --- /dev/null +++ b/js/react_native/cpp/JsiHelper.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include +#include +#include + +#define BIND_METHOD(method) \ + std::bind(&method, std::placeholders::_1, std::placeholders::_2, \ + std::placeholders::_3, std::placeholders::_4) + +#define BIND_GETTER(method) std::bind(&method, std::placeholders::_1) + +#define BIND_SETTER(method) \ + std::bind(&method, std::placeholders::_1, std::placeholders::_2) + +#define BIND_THIS_METHOD(cls, name) \ + std::bind(&cls::name##_method, this, std::placeholders::_1, \ + std::placeholders::_2, std::placeholders::_3, \ + std::placeholders::_4) + +#define BIND_THIS_GETTER(cls, name) \ + std::bind(&cls::name##_get, this, std::placeholders::_1) + +#define BIND_THIS_SETTER(cls, name) \ + std::bind(&cls::name##_set, this, std::placeholders::_1, \ + std::placeholders::_2) + +#define METHOD_INFO(cls, name, count) \ + { \ + #name, { BIND_THIS_METHOD(cls, name), count } \ + } + +#define GETTER_INFO(cls, name) \ + {#name, BIND_THIS_GETTER(cls, name)} + +#define DEFINE_METHOD(name) \ + Value name##_method(Runtime& runtime, const Value& thisValue, \ + const Value* arguments, size_t count) + +#define DEFINE_GETTER(name) Value name##_get(Runtime& runtime) + +#define DEFINE_SETTER(name) \ + void name##_set(Runtime& runtime, const Value& value) + +typedef std::function + JsiMethod; +typedef std::function + JsiGetter; +typedef std::function + JsiSetter; + +struct JsiMethodInfo { + JsiMethod method; + size_t count; +}; + +typedef std::unordered_map JsiMethodMap; +typedef std::unordered_map JsiGetterMap; +typedef std::unordered_map JsiSetterMap; + +class HostObjectHelper : public facebook::jsi::HostObject { + public: + HostObjectHelper( + JsiMethodMap methods = {}, + JsiGetterMap getters = {}, + JsiSetterMap setters = {}) + : methods_(methods), + getters_(getters), + setters_(setters) {} + + std::vector + getPropertyNames(facebook::jsi::Runtime& runtime) override { + std::vector names; + for (auto& [name, _] : methods_) { + names.push_back(facebook::jsi::PropNameID::forUtf8(runtime, name)); + } + for (auto& [name, _] : getters_) { + names.push_back(facebook::jsi::PropNameID::forUtf8(runtime, name)); + } + return names; + } + + facebook::jsi::Value get(facebook::jsi::Runtime& runtime, + const facebook::jsi::PropNameID& name) override { + auto method = methods_.find(name.utf8(runtime)); + if (method != methods_.end()) { + return 
facebook::jsi::Function::createFromHostFunction(runtime, name, method->second.count, + method->second.method); + } + + auto getter = getters_.find(name.utf8(runtime)); + if (getter != getters_.end()) { + return getter->second(runtime); + } + + return facebook::jsi::Value::undefined(); + } + + void set(facebook::jsi::Runtime& runtime, const facebook::jsi::PropNameID& name, + const facebook::jsi::Value& value) override { + auto setter = setters_.find(name.utf8(runtime)); + if (setter != setters_.end()) { + setter->second(runtime, value); + } + } + + private: + JsiMethodMap methods_; + JsiGetterMap getters_; + JsiSetterMap setters_; +}; diff --git a/js/react_native/cpp/JsiMain.cpp b/js/react_native/cpp/JsiMain.cpp new file mode 100644 index 0000000000000..26e8842b32793 --- /dev/null +++ b/js/react_native/cpp/JsiMain.cpp @@ -0,0 +1,98 @@ +#include "JsiMain.h" +#include "InferenceSessionHostObject.h" +#include "JsiHelper.h" +#include "SessionUtils.h" +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +std::shared_ptr +install(Runtime& runtime, + std::shared_ptr jsInvoker) { + auto env = std::make_shared(jsInvoker); + try { + auto ortApi = Object(runtime); + + auto initOrtOnceMethod = Function::createFromHostFunction( + runtime, PropNameID::forAscii(runtime, "initOrtOnce"), 2, + [env](Runtime& runtime, const Value& thisValue, const Value* arguments, + size_t count) -> Value { + try { + OrtLoggingLevel logLevel = ORT_LOGGING_LEVEL_WARNING; + if (count > 0 && arguments[0].isNumber()) { + int level = static_cast(arguments[0].asNumber()); + switch (level) { + case 0: + logLevel = ORT_LOGGING_LEVEL_VERBOSE; + break; + case 1: + logLevel = ORT_LOGGING_LEVEL_INFO; + break; + case 2: + logLevel = ORT_LOGGING_LEVEL_WARNING; + break; + case 3: + logLevel = ORT_LOGGING_LEVEL_ERROR; + break; + case 4: + logLevel = ORT_LOGGING_LEVEL_FATAL; + break; + default: + logLevel = ORT_LOGGING_LEVEL_WARNING; + break; + } + } + env->setTensorConstructor(std::make_shared( + runtime, arguments[1].asObject(runtime))); + env->initOrtEnv(logLevel, "onnxruntime-react-native-jsi"); + return Value::undefined(); + } catch (const std::exception& e) { + throw JSError(runtime, "Failed to initialize ONNX Runtime: " + + std::string(e.what())); + } + }); + + ortApi.setProperty(runtime, "initOrtOnce", initOrtOnceMethod); + + auto createInferenceSessionMethod = Function::createFromHostFunction( + runtime, PropNameID::forAscii(runtime, "createInferenceSession"), 0, + std::bind(InferenceSessionHostObject::constructor, env, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + ortApi.setProperty(runtime, "createInferenceSession", + createInferenceSessionMethod); + + auto listSupportedBackendsMethod = Function::createFromHostFunction( + runtime, PropNameID::forAscii(runtime, "listSupportedBackends"), 0, + [](Runtime& runtime, const Value& thisValue, const Value* arguments, + size_t count) -> Value { + auto backends = Array(runtime, supportedBackends.size()); + for (size_t i = 0; i < supportedBackends.size(); i++) { + auto backend = Object(runtime); + backend.setProperty( + runtime, "name", + String::createFromUtf8(runtime, supportedBackends[i])); + backends.setValueAtIndex(runtime, i, Value(runtime, backend)); + } + return Value(runtime, backends); + }); + + ortApi.setProperty(runtime, "listSupportedBackends", + listSupportedBackendsMethod); + + ortApi.setProperty( + runtime, "version", + String::createFromUtf8(runtime, OrtGetApiBase()->GetVersionString())); + + 
runtime.global().setProperty(runtime, "OrtApi", ortApi); + } catch (const std::exception& e) { + throw JSError(runtime, "Failed to install ONNX Runtime JSI bindings: " + + std::string(e.what())); + } + + return env; +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/JsiMain.h b/js/react_native/cpp/JsiMain.h new file mode 100644 index 0000000000000..15c8da084c746 --- /dev/null +++ b/js/react_native/cpp/JsiMain.h @@ -0,0 +1,13 @@ +#pragma once + +#include "Env.h" +#include +#include + +namespace onnxruntimejsi { + +std::shared_ptr +install(facebook::jsi::Runtime& runtime, + std::shared_ptr jsInvoker = nullptr); + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/JsiUtils.cpp b/js/react_native/cpp/JsiUtils.cpp new file mode 100644 index 0000000000000..a5c802f36ef30 --- /dev/null +++ b/js/react_native/cpp/JsiUtils.cpp @@ -0,0 +1,32 @@ +#include "JsiUtils.h" + +using namespace facebook::jsi; + +bool isTypedArray(Runtime& runtime, const Object& jsObj) { + if (!jsObj.hasProperty(runtime, "buffer")) + return false; + if (!jsObj.getProperty(runtime, "buffer") + .asObject(runtime) + .isArrayBuffer(runtime)) + return false; + return true; +} + +void forEach(Runtime& runtime, const Object& object, + const std::function& callback) { + auto names = object.getPropertyNames(runtime); + for (size_t i = 0; i < names.size(runtime); i++) { + auto key = + names.getValueAtIndex(runtime, i).asString(runtime).utf8(runtime); + auto value = object.getProperty(runtime, key.c_str()); + callback(key, value, i); + } +} + +void forEach(Runtime& runtime, const Array& array, + const std::function& callback) { + for (size_t i = 0; i < array.size(runtime); i++) { + callback(array.getValueAtIndex(runtime, i), i); + } +} diff --git a/js/react_native/cpp/JsiUtils.h b/js/react_native/cpp/JsiUtils.h new file mode 100644 index 0000000000000..be38d67868df4 --- /dev/null +++ b/js/react_native/cpp/JsiUtils.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +bool isTypedArray(facebook::jsi::Runtime& runtime, + const facebook::jsi::Object& jsObj); + +void forEach( + facebook::jsi::Runtime& runtime, const facebook::jsi::Object& object, + const std::function& callback); + +void forEach( + facebook::jsi::Runtime& runtime, const facebook::jsi::Array& array, + const std::function& callback); diff --git a/js/react_native/cpp/SessionUtils.cpp b/js/react_native/cpp/SessionUtils.cpp new file mode 100644 index 0000000000000..3e6672ec546fb --- /dev/null +++ b/js/react_native/cpp/SessionUtils.cpp @@ -0,0 +1,450 @@ +#include "SessionUtils.h" +#include "JsiUtils.h" +#include "cpu_provider_factory.h" +#include +#include "onnxruntime_cxx_api.h" +#ifdef USE_NNAPI +#include "nnapi_provider_factory.h" +#endif +#ifdef USE_COREML +#include "coreml_provider_factory.h" +#endif + +// Note: Using below syntax for including ort c api and ort extensions headers to resolve a compiling error happened +// in an expo react native ios app when ort extensions enabled (a redefinition error of multiple object types defined +// within ORT C API header). It's an edge case that compiler allows both ort c api headers to be included when #include +// syntax doesn't match. For the case when extensions not enabled, it still requires a onnxruntime prefix directory for +// searching paths. Also in general, it's a convention to use #include for C/C++ headers rather then #import. 
See: +// https://google.github.io/styleguide/objcguide.html#import-and-include +// https://microsoft.github.io/objc-guide/Headers/ImportAndInclude.html +#if defined(ORT_ENABLE_EXTENSIONS) && defined(__APPLE__) +#include +#endif + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +const std::vector supportedBackends = { + "cpu", + "xnnpack", +#ifdef USE_COREML + "coreml", +#endif +#ifdef USE_NNAPI + "nnapi", +#endif +#ifdef USE_QNN + "qnn", +#endif +}; + +class ExtendedSessionOptions : public Ort::SessionOptions { + public: + ExtendedSessionOptions() = default; + + void AppendExecutionProvider_CPU(int use_arena) { + Ort::ThrowOnError( + OrtSessionOptionsAppendExecutionProvider_CPU(this->p_, use_arena)); + } + + void AddFreeDimensionOverrideByName(const char* name, int64_t value) { + Ort::ThrowOnError( + Ort::GetApi().AddFreeDimensionOverrideByName(this->p_, name, value)); + } +#ifdef USE_NNAPI + void AppendExecutionProvider_Nnapi(uint32_t nnapi_flags) { + Ort::ThrowOnError( + OrtSessionOptionsAppendExecutionProvider_Nnapi(this->p_, nnapi_flags)); + } +#endif +#ifdef USE_COREML + void AppendExecutionProvider_CoreML(int flags) { + Ort::ThrowOnError( + OrtSessionOptionsAppendExecutionProvider_CoreML(this->p_, flags)); + } +#endif +}; + +void parseSessionOptions(Runtime& runtime, const Value& optionsValue, + Ort::SessionOptions& sessionOptions) { + if (!optionsValue.isObject()) + return; + + auto options = optionsValue.asObject(runtime); + + try { +#ifdef ORT_ENABLE_EXTENSIONS + // ortExtLibPath + if (options.hasProperty(runtime, "ortExtLibPath")) { +#ifdef __APPLE__ + Ort::ThrowOnError(RegisterCustomOps(sessionOptions, OrtGetApiBase())); +#endif +#ifdef __ANDROID__ + auto prop = options.getProperty(runtime, "ortExtLibPath"); + if (prop.isString()) { + std::string libraryPath = prop.asString(runtime).utf8(runtime); + sessionOptions.RegisterCustomOpsLibrary(libraryPath.c_str()); + } +#endif + } +#endif + + // intraOpNumThreads + if (options.hasProperty(runtime, "intraOpNumThreads")) { + auto prop = options.getProperty(runtime, "intraOpNumThreads"); + if (prop.isNumber()) { + int numThreads = static_cast(prop.asNumber()); + if (numThreads > 0) { + sessionOptions.SetIntraOpNumThreads(numThreads); + } + } + } + + // interOpNumThreads + if (options.hasProperty(runtime, "interOpNumThreads")) { + auto prop = options.getProperty(runtime, "interOpNumThreads"); + if (prop.isNumber()) { + int numThreads = static_cast(prop.asNumber()); + if (numThreads > 0) { + sessionOptions.SetInterOpNumThreads(numThreads); + } + } + } + + // freeDimensionOverrides + if (options.hasProperty(runtime, "freeDimensionOverrides")) { + auto prop = options.getProperty(runtime, "freeDimensionOverrides"); + if (prop.isObject()) { + auto overrides = prop.asObject(runtime); + forEach(runtime, overrides, + [&](const std::string& key, const Value& value, size_t index) { + reinterpret_cast(sessionOptions) + .AddFreeDimensionOverrideByName( + key.c_str(), static_cast(value.asNumber())); + }); + } + } + + // graphOptimizationLevel + if (options.hasProperty(runtime, "graphOptimizationLevel")) { + auto prop = options.getProperty(runtime, "graphOptimizationLevel"); + if (prop.isString()) { + std::string level = prop.asString(runtime).utf8(runtime); + if (level == "disabled") { + sessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + } else if (level == "basic") { + sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); + } else if (level == "extended") { + 
sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED); + } else if (level == "all") { + sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + } + } + } + + // enableCpuMemArena + if (options.hasProperty(runtime, "enableCpuMemArena")) { + auto prop = options.getProperty(runtime, "enableCpuMemArena"); + if (prop.isBool()) { + if (prop.asBool()) { + sessionOptions.EnableCpuMemArena(); + } else { + sessionOptions.DisableCpuMemArena(); + } + } + } + + // enableMemPattern + if (options.hasProperty(runtime, "enableMemPattern")) { + auto prop = options.getProperty(runtime, "enableMemPattern"); + if (prop.isBool()) { + if (prop.asBool()) { + sessionOptions.EnableMemPattern(); + } else { + sessionOptions.DisableMemPattern(); + } + } + } + + // executionMode + if (options.hasProperty(runtime, "executionMode")) { + auto prop = options.getProperty(runtime, "executionMode"); + if (prop.isString()) { + std::string mode = prop.asString(runtime).utf8(runtime); + if (mode == "sequential") { + sessionOptions.SetExecutionMode(ORT_SEQUENTIAL); + } else if (mode == "parallel") { + sessionOptions.SetExecutionMode(ORT_PARALLEL); + } + } + } + + // optimizedModelFilePath + if (options.hasProperty(runtime, "optimizedModelFilePath")) { + auto prop = options.getProperty(runtime, "optimizedModelFilePath"); + if (prop.isString()) { + std::string path = prop.asString(runtime).utf8(runtime); + sessionOptions.SetOptimizedModelFilePath(path.c_str()); + } + } + + // enableProfiling + if (options.hasProperty(runtime, "enableProfiling")) { + auto prop = options.getProperty(runtime, "enableProfiling"); + if (prop.isBool() && prop.asBool()) { + sessionOptions.EnableProfiling("onnxruntime_profile_"); + } + } + + // profileFilePrefix + if (options.hasProperty(runtime, "profileFilePrefix")) { + auto enableProfilingProp = + options.getProperty(runtime, "enableProfiling"); + if (enableProfilingProp.isBool() && enableProfilingProp.asBool()) { + auto prop = options.getProperty(runtime, "profileFilePrefix"); + if (prop.isString()) { + std::string prefix = prop.asString(runtime).utf8(runtime); + sessionOptions.EnableProfiling(prefix.c_str()); + } + } + } + + // logId + if (options.hasProperty(runtime, "logId")) { + auto prop = options.getProperty(runtime, "logId"); + if (prop.isString()) { + std::string logId = prop.asString(runtime).utf8(runtime); + sessionOptions.SetLogId(logId.c_str()); + } + } + + // logSeverityLevel + if (options.hasProperty(runtime, "logSeverityLevel")) { + auto prop = options.getProperty(runtime, "logSeverityLevel"); + if (prop.isNumber()) { + int level = static_cast(prop.asNumber()); + if (level >= 0 && level <= 4) { + sessionOptions.SetLogSeverityLevel(level); + } + } + } + + // externalData + if (options.hasProperty(runtime, "externalData")) { + auto prop = + options.getProperty(runtime, "externalData").asObject(runtime); + if (prop.isArray(runtime)) { + auto externalDataArray = prop.asArray(runtime); + std::vector paths; + std::vector buffs; + std::vector sizes; + forEach( + runtime, externalDataArray, [&](const Value& value, size_t index) { + if (value.isObject()) { + auto externalDataObject = value.asObject(runtime); + if (externalDataObject.hasProperty(runtime, "path")) { + auto pathValue = + externalDataObject.getProperty(runtime, "path"); + if (pathValue.isString()) { + paths.push_back(pathValue.asString(runtime).utf8(runtime)); + } + } + if (externalDataObject.hasProperty(runtime, "data")) { + auto dataValue = + externalDataObject.getProperty(runtime, "data") + .asObject(runtime); + if 
(isTypedArray(runtime, dataValue)) { + auto arrayBuffer = dataValue.getProperty(runtime, "buffer") + .asObject(runtime) + .getArrayBuffer(runtime); + buffs.push_back( + reinterpret_cast(arrayBuffer.data(runtime))); + sizes.push_back(arrayBuffer.size(runtime)); + } + } + } + }); + sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, + sizes); + } + } + + // executionProviders + if (options.hasProperty(runtime, "executionProviders")) { + auto prop = options.getProperty(runtime, "executionProviders"); + if (prop.isObject() && prop.asObject(runtime).isArray(runtime)) { + auto providers = prop.asObject(runtime).asArray(runtime); + forEach(runtime, providers, [&](const Value& epValue, size_t index) { + std::string epName; + std::unique_ptr providerObj; + if (epValue.isString()) { + epName = epValue.asString(runtime).utf8(runtime); + } else if (epValue.isObject()) { + providerObj = std::make_unique(epValue.asObject(runtime)); + epName = providerObj->getProperty(runtime, "name") + .asString(runtime) + .utf8(runtime); + } + + // Apply execution providers + if (epName == "cpu") { + int use_arena = 0; + if (providerObj && providerObj->hasProperty(runtime, "useArena")) { + auto useArena = providerObj->getProperty(runtime, "useArena"); + if (useArena.isBool() && useArena.asBool()) { + use_arena = 1; + } + } + reinterpret_cast(sessionOptions) + .AppendExecutionProvider_CPU(use_arena); + } else if (epName == "xnnpack") { + sessionOptions.AppendExecutionProvider("XNNPACK"); + } +#ifdef USE_COREML + else if (epName == "coreml") { + int flags = 0; + if (providerObj && + providerObj->hasProperty(runtime, "coreMlFlags")) { + auto flagsValue = + providerObj->getProperty(runtime, "coreMlFlags"); + if (flagsValue.isNumber()) { + flags = static_cast(flagsValue.asNumber()); + } + } + reinterpret_cast(sessionOptions) + .AppendExecutionProvider_CoreML(flags); + } +#endif +#ifdef USE_NNAPI + else if (epName == "nnapi") { + uint32_t nnapi_flags = 0; + if (providerObj && providerObj->hasProperty(runtime, "useFP16")) { + auto useFP16 = providerObj->getProperty(runtime, "useFP16"); + if (useFP16.isBool() && useFP16.asBool()) { + nnapi_flags |= NNAPI_FLAG_USE_FP16; + } + } + if (providerObj && providerObj->hasProperty(runtime, "useNCHW")) { + auto useNCHW = providerObj->getProperty(runtime, "useNCHW"); + if (useNCHW.isBool() && useNCHW.asBool()) { + nnapi_flags |= NNAPI_FLAG_USE_NCHW; + } + } + if (providerObj && + providerObj->hasProperty(runtime, "cpuDisabled")) { + auto cpuDisabled = + providerObj->getProperty(runtime, "cpuDisabled"); + if (cpuDisabled.isBool() && cpuDisabled.asBool()) { + nnapi_flags |= NNAPI_FLAG_CPU_DISABLED; + } + } + if (providerObj && providerObj->hasProperty(runtime, "cpuOnly")) { + auto cpuOnly = providerObj->getProperty(runtime, "cpuOnly"); + if (cpuOnly.isBool() && cpuOnly.asBool()) { + nnapi_flags |= NNAPI_FLAG_CPU_ONLY; + } + } + reinterpret_cast(sessionOptions) + .AppendExecutionProvider_Nnapi(nnapi_flags); + } +#endif +#ifdef USE_QNN + else if (epName == "qnn") { + std::unordered_map options; + if (providerObj && + providerObj->hasProperty(runtime, "backendType")) { + options["backendType"] = + providerObj->getProperty(runtime, "backendType") + .asString(runtime) + .utf8(runtime); + } + if (providerObj && + providerObj->hasProperty(runtime, "backendPath")) { + options["backendPath"] = + providerObj->getProperty(runtime, "backendPath") + .asString(runtime) + .utf8(runtime); + } + if (providerObj && + providerObj->hasProperty(runtime, "enableFp16Precision")) { + auto 
enableFp16Precision = + providerObj->getProperty(runtime, "enableFp16Precision"); + if (enableFp16Precision.isBool() && + enableFp16Precision.asBool()) { + options["enableFp16Precision"] = "1"; + } else { + options["enableFp16Precision"] = "0"; + } + } + sessionOptions.AppendExecutionProvider("QNN", options); + } +#endif + else { + throw JSError(runtime, "Unsupported execution provider: " + epName); + } + }); + } + } + } catch (const JSError& e) { + throw e; + } catch (const std::exception& e) { + throw JSError(runtime, + "Failed to parse session options: " + std::string(e.what())); + } +} + +void parseRunOptions(Runtime& runtime, const Value& optionsValue, + Ort::RunOptions& runOptions) { + if (!optionsValue.isObject()) + return; + + auto options = optionsValue.asObject(runtime); + + try { + // tag + if (options.hasProperty(runtime, "tag")) { + auto prop = options.getProperty(runtime, "tag"); + if (prop.isString()) { + std::string tag = prop.asString(runtime).utf8(runtime); + runOptions.SetRunTag(tag.c_str()); + } + } + + // logSeverityLevel + if (options.hasProperty(runtime, "logSeverityLevel")) { + auto prop = options.getProperty(runtime, "logSeverityLevel"); + if (prop.isNumber()) { + int level = static_cast(prop.asNumber()); + if (level >= 0 && level <= 4) { + runOptions.SetRunLogSeverityLevel(level); + } + } + } + + // logVerbosityLevel + if (options.hasProperty(runtime, "logVerbosityLevel")) { + auto prop = options.getProperty(runtime, "logVerbosityLevel"); + if (prop.isNumber()) { + int level = static_cast(prop.asNumber()); + if (level >= 0) { + runOptions.SetRunLogVerbosityLevel(level); + } + } + } + + // terminate + if (options.hasProperty(runtime, "terminate")) { + auto prop = options.getProperty(runtime, "terminate"); + if (prop.isBool() && prop.asBool()) { + runOptions.SetTerminate(); + } + } + + } catch (const std::exception& e) { + throw JSError(runtime, + "Failed to parse run options: " + std::string(e.what())); + } +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/SessionUtils.h b/js/react_native/cpp/SessionUtils.h new file mode 100644 index 0000000000000..4dafcd01ab845 --- /dev/null +++ b/js/react_native/cpp/SessionUtils.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include "onnxruntime_cxx_api.h" + +namespace onnxruntimejsi { + +extern const std::vector supportedBackends; + +void parseSessionOptions(facebook::jsi::Runtime& runtime, + const facebook::jsi::Value& optionsValue, + Ort::SessionOptions& sessionOptions); + +void parseRunOptions(facebook::jsi::Runtime& runtime, + const facebook::jsi::Value& optionsValue, + Ort::RunOptions& runOptions); + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/TensorUtils.cpp b/js/react_native/cpp/TensorUtils.cpp new file mode 100644 index 0000000000000..79d270d883294 --- /dev/null +++ b/js/react_native/cpp/TensorUtils.cpp @@ -0,0 +1,236 @@ +#include "TensorUtils.h" +#include "JsiUtils.h" +#include +#include +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +static const std::unordered_map + dataTypeToStringMap = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, "float32"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, "uint8"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, "int8"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, "uint16"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, "int16"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, "int32"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, "int64"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, "string"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, "bool"}, + 
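parseSessionOptions and parseRunOptions above walk plain JS objects key by key, so the shape they accept can be written down directly. The TypeScript literals below are illustrative values assembled only from keys read in this diff (intraOpNumThreads, graphOptimizationLevel, executionProviders, tag, terminate, and so on); they are not defaults and not an exhaustive list.

```typescript
// Example option objects matching the keys parseSessionOptions / parseRunOptions read.
const sessionOptions = {
  intraOpNumThreads: 2,
  graphOptimizationLevel: 'extended' as const, // 'disabled' | 'basic' | 'extended' | 'all'
  executionMode: 'sequential' as const,        // or 'parallel'
  enableCpuMemArena: true,
  logSeverityLevel: 2,                         // 0..4, forwarded to SetLogSeverityLevel
  executionProviders: [
    'xnnpack',                                 // string form
    { name: 'nnapi', useFP16: true },          // object form with provider-specific flags
  ],
};

const runOptions = {
  tag: 'example-run',      // SetRunTag
  logSeverityLevel: 3,     // SetRunLogSeverityLevel
  terminate: false,        // SetTerminate is called only when true
};
```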
{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, "float16"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, "float64"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, "uint32"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, "uint64"}, +}; + +static const std::unordered_map + elementSizeMap = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, sizeof(int16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, sizeof(int64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, sizeof(char*)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, sizeof(bool)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, 2}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, sizeof(double)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, sizeof(uint32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, sizeof(uint64_t)}, +}; + +static const std::unordered_map + dataTypeToTypedArrayMap = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, "Float32Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, "Float64Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, "Int32Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, "BigInt64Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, "Uint32Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, "BigUint64Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, "Uint8Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, "Int8Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, "Uint16Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, "Int16Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, "Float16Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, "Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, "Uint8Array"}, +}; + +inline size_t getElementSize(ONNXTensorElementDataType dataType) { + auto it = elementSizeMap.find(dataType); + if (it != elementSizeMap.end()) { + return it->second; + } + throw std::invalid_argument("Unsupported or unknown tensor data type: " + + std::to_string(static_cast(dataType))); +} + +bool TensorUtils::isTensor(Runtime& runtime, const Object& obj) { + return obj.hasProperty(runtime, "cpuData") && + obj.hasProperty(runtime, "dims") && obj.hasProperty(runtime, "type"); +} + +inline Object getTypedArrayConstructor(Runtime& runtime, + const ONNXTensorElementDataType type) { + auto it = dataTypeToTypedArrayMap.find(type); + if (it != dataTypeToTypedArrayMap.end()) { + auto prop = runtime.global().getProperty(runtime, it->second); + if (prop.isObject()) { + return prop.asObject(runtime); + } else { + throw JSError(runtime, "TypedArray constructor not found: " + + std::string(it->second)); + } + } + throw JSError(runtime, + "Unsupported tensor data type for TypedArray creation: " + + std::to_string(static_cast(type))); +} + +size_t getElementCount(const std::vector& shape) { + size_t count = 1; + for (auto dim : shape) { + count *= dim; + } + return count; +} + +Ort::Value +TensorUtils::createOrtValueFromJSTensor(Runtime& runtime, + const Object& tensorObj, + const Ort::MemoryInfo& memoryInfo) { + if (!isTensor(runtime, tensorObj)) { + throw JSError( + runtime, + "Invalid tensor object: missing cpuData, dims, or type properties"); + } + + auto dataProperty = tensorObj.getProperty(runtime, "cpuData"); + auto dimsProperty = tensorObj.getProperty(runtime, "dims"); + auto typeProperty = tensorObj.getProperty(runtime, "type"); + + if (!dimsProperty.isObject() || + !dimsProperty.asObject(runtime).isArray(runtime)) { 
+ throw JSError(runtime, "Tensor dims must be array"); + } + + if (!typeProperty.isString()) { + throw JSError(runtime, "Tensor type must be string"); + } + + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + auto typeStr = typeProperty.asString(runtime).utf8(runtime); + for (auto it = dataTypeToStringMap.begin(); it != dataTypeToStringMap.end(); + ++it) { + if (it->second == typeStr) { + type = it->first; + break; + } + } + if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + throw JSError(runtime, "Unsupported tensor data type: " + typeStr); + } + + void* data = nullptr; + auto dataObj = dataProperty.asObject(runtime); + + if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + if (!dataObj.isArray(runtime)) { + throw JSError(runtime, "Tensor data must be an array of strings"); + } + auto array = dataObj.asArray(runtime); + auto size = array.size(runtime); + data = new char*[size]; + for (size_t i = 0; i < size; ++i) { + auto item = array.getValueAtIndex(runtime, i); + static_cast(data)[i] = + strdup(item.toString(runtime).utf8(runtime).c_str()); + } + } else { + if (!isTypedArray(runtime, dataObj)) { + throw JSError(runtime, "Tensor data must be a TypedArray"); + } + auto buffer = dataObj.getProperty(runtime, "buffer") + .asObject(runtime) + .getArrayBuffer(runtime); + data = buffer.data(runtime); + } + + std::vector shape; + auto dimsArray = dimsProperty.asObject(runtime).asArray(runtime); + for (size_t i = 0; i < dimsArray.size(runtime); ++i) { + auto dim = dimsArray.getValueAtIndex(runtime, i); + if (dim.isNumber()) { + shape.push_back(static_cast(dim.asNumber())); + } + } + + return Ort::Value::CreateTensor(memoryInfo, data, + getElementCount(shape) * getElementSize(type), + shape.data(), shape.size(), type); +} + +Object +TensorUtils::createJSTensorFromOrtValue(Runtime& runtime, Ort::Value& ortValue, + const Object& tensorConstructor) { + auto typeInfo = ortValue.GetTensorTypeAndShapeInfo(); + auto shape = typeInfo.GetShape(); + auto elementType = typeInfo.GetElementType(); + + std::string typeStr; + auto it = dataTypeToStringMap.find(elementType); + if (it != dataTypeToStringMap.end()) { + typeStr = it->second; + } else { + throw JSError(runtime, + "Unsupported tensor data type for TypedArray creation: " + + std::to_string(static_cast(elementType))); + } + + auto dimsArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); ++j) { + dimsArray.setValueAtIndex(runtime, j, Value(static_cast(shape[j]))); + } + + if (elementType != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + void* rawData = ortValue.GetTensorMutableRawData(); + size_t elementCount = + ortValue.GetTensorTypeAndShapeInfo().GetElementCount(); + size_t elementSize = getElementSize(elementType); + size_t dataSize = elementCount * elementSize; + + auto typedArrayCtor = getTypedArrayConstructor(runtime, elementType); + auto typedArrayInstance = + typedArrayCtor.asFunction(runtime).callAsConstructor( + runtime, static_cast(elementCount)); + + auto buffer = typedArrayInstance.asObject(runtime) + .getProperty(runtime, "buffer") + .asObject(runtime) + .getArrayBuffer(runtime); + memcpy(buffer.data(runtime), rawData, dataSize); + + auto tensorInstance = + tensorConstructor.asFunction(runtime).callAsConstructor( + runtime, typeStr, typedArrayInstance, dimsArray); + + return tensorInstance.asObject(runtime); + } else { + auto strArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); ++j) { + strArray.setValueAtIndex( + runtime, j, Value(runtime, 
String::createFromUtf8(runtime, ""))); + } + + auto tensorInstance = + tensorConstructor.asFunction(runtime).callAsConstructor( + runtime, typeStr, strArray, dimsArray); + + return tensorInstance.asObject(runtime); + } +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/TensorUtils.h b/js/react_native/cpp/TensorUtils.h new file mode 100644 index 0000000000000..5361f5cb1101f --- /dev/null +++ b/js/react_native/cpp/TensorUtils.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include "onnxruntime_cxx_api.h" +#include +#include + +namespace onnxruntimejsi { + +class TensorUtils { + public: + static Ort::Value + createOrtValueFromJSTensor(facebook::jsi::Runtime& runtime, + const facebook::jsi::Object& tensorObj, + const Ort::MemoryInfo& memoryInfo); + + static facebook::jsi::Object + createJSTensorFromOrtValue(facebook::jsi::Runtime& runtime, + Ort::Value& ortValue, + const facebook::jsi::Object& tensorConstructor); + + static bool isTensor(facebook::jsi::Runtime& runtime, + const facebook::jsi::Object& obj); +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/e2e/android/app/build.gradle b/js/react_native/e2e/android/app/build.gradle index fa94c00a32bd0..54d5e55a209d8 100644 --- a/js/react_native/e2e/android/app/build.gradle +++ b/js/react_native/e2e/android/app/build.gradle @@ -116,17 +116,10 @@ android { } } -repositories { - flatDir { - dir 'libs' - } -} - dependencies { androidTestImplementation('com.wix:detox:+') implementation 'androidx.appcompat:appcompat:1.1.0' - implementation fileTree(dir: "libs", include: ["*.jar"]) // The version of react-native is set by the React Native Gradle Plugin implementation("com.facebook.react:react-android") implementation("com.facebook.react:flipper-integration") @@ -143,8 +136,6 @@ dependencies { androidTestImplementation "androidx.test:rules:1.5.0" implementation (project(':onnxruntime-react-native')) - // specify ORT dependency here so it can be found in libs flatDir repository - implementation "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar" } // Run this once to be able to run the application with BUCK diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_bool.onnx b/js/react_native/e2e/android/app/src/main/assets/test_types_bool.onnx similarity index 100% rename from js/react_native/android/src/androidTest/res/raw/test_types_bool.onnx rename to js/react_native/e2e/android/app/src/main/assets/test_types_bool.onnx diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_double.onnx b/js/react_native/e2e/android/app/src/main/assets/test_types_double.onnx similarity index 100% rename from js/react_native/android/src/androidTest/res/raw/test_types_double.onnx rename to js/react_native/e2e/android/app/src/main/assets/test_types_double.onnx diff --git a/js/react_native/e2e/android/app/src/main/assets/test_types_float.ort b/js/react_native/e2e/android/app/src/main/assets/test_types_float.ort new file mode 100644 index 0000000000000..2e8377a4afc1f Binary files /dev/null and b/js/react_native/e2e/android/app/src/main/assets/test_types_float.ort differ diff --git a/js/react_native/e2e/android/app/src/main/assets/test_types_int32.ort b/js/react_native/e2e/android/app/src/main/assets/test_types_int32.ort new file mode 100644 index 0000000000000..15d3cc1f8903f Binary files /dev/null and b/js/react_native/e2e/android/app/src/main/assets/test_types_int32.ort differ diff --git a/js/react_native/e2e/android/app/src/main/assets/test_types_int64.ort 
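TensorUtils::isTensor above accepts any JS object that carries `cpuData`, `dims`, and `type`, where `cpuData` is a TypedArray whose `.buffer` backs the Ort::Value and `type` is one of the strings in dataTypeToStringMap. A minimal object of that shape, written as a TypeScript literal, is sketched below; the variable name is illustrative only.

```typescript
// Minimal shape accepted by TensorUtils::isTensor / createOrtValueFromJSTensor above:
// 'type' must match dataTypeToStringMap (e.g. 'float32', 'int64', 'bool'),
// 'dims' is the shape, and 'cpuData' is a TypedArray backed by an ArrayBuffer.
const exampleTensor = {
  type: 'float32',
  dims: [1, 3],
  cpuData: new Float32Array([0.1, 0.2, 0.3]),
};
```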
b/js/react_native/e2e/android/app/src/main/assets/test_types_int64.ort new file mode 100644 index 0000000000000..e0cbaa9d86392 Binary files /dev/null and b/js/react_native/e2e/android/app/src/main/assets/test_types_int64.ort differ diff --git a/js/react_native/e2e/android/app/src/main/assets/test_types_int8.ort b/js/react_native/e2e/android/app/src/main/assets/test_types_int8.ort new file mode 100644 index 0000000000000..9d2ef52138cfc Binary files /dev/null and b/js/react_native/e2e/android/app/src/main/assets/test_types_int8.ort differ diff --git a/js/react_native/e2e/android/app/src/main/assets/test_types_uint8.ort b/js/react_native/e2e/android/app/src/main/assets/test_types_uint8.ort new file mode 100644 index 0000000000000..a0a5d6a1d0177 Binary files /dev/null and b/js/react_native/e2e/android/app/src/main/assets/test_types_uint8.ort differ diff --git a/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java b/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java index b5f58f39ea8ca..72f27b3291a8c 100644 --- a/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java +++ b/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java @@ -5,8 +5,6 @@ import static java.util.stream.Collectors.joining; -import ai.onnxruntime.reactnative.OnnxruntimeModule; -import ai.onnxruntime.reactnative.TensorHelper; import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; @@ -149,7 +147,7 @@ private WritableMap preprocess(String uri) throws Exception { inputTensorMap.putArray("dims", dims); // type - inputTensorMap.putString("type", TensorHelper.JsTensorTypeFloat); + inputTensorMap.putString("type", "float32"); // data encoded as Base64 imageByteBuffer.rewind(); diff --git a/js/react_native/e2e/android/build.gradle b/js/react_native/e2e/android/build.gradle index 9ad8256fc52dc..8d1f9d59d8649 100644 --- a/js/react_native/e2e/android/build.gradle +++ b/js/react_native/e2e/android/build.gradle @@ -39,6 +39,10 @@ allprojects { // Add Detox as a precompiled native dependency url("$rootDir/../node_modules/detox/Detox-android") } + maven { + // Local onnxruntime-android package + url("$rootDir/app/libs") + } google() mavenCentral() @@ -46,4 +50,4 @@ allprojects { } } -apply plugin: "com.facebook.react.rootproject" \ No newline at end of file +apply plugin: "com.facebook.react.rootproject" diff --git a/js/react_native/e2e/android/gradle.properties b/js/react_native/e2e/android/gradle.properties index ede6147623f19..a8840c7b7d214 100644 --- a/js/react_native/e2e/android/gradle.properties +++ b/js/react_native/e2e/android/gradle.properties @@ -22,4 +22,6 @@ android.enableJetifier=true org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=2048m -Dkotlin.daemon.jvm.options=-Xmx8192m # Use this property to enable or disable the Hermes JS engine. # If set to false, you will be using JSC instead. 
-hermesEnabled=false +hermesEnabled=true + +reactNativeArchitectures=x86_64 diff --git a/js/react_native/e2e/ios/MNISTDataHandler.h b/js/react_native/e2e/ios/MNISTDataHandler.h index da05843e8a41f..595cae82ea91c 100644 --- a/js/react_native/e2e/ios/MNISTDataHandler.h +++ b/js/react_native/e2e/ios/MNISTDataHandler.h @@ -4,7 +4,7 @@ #ifndef MNISTDataHandler_h #define MNISTDataHandler_h -#import +#import @interface MNISTDataHandler : NSObject @end diff --git a/js/react_native/e2e/ios/MNISTDataHandler.mm b/js/react_native/e2e/ios/MNISTDataHandler.mm index 1a79b66ca5d2f..6c27607eff1ed 100644 --- a/js/react_native/e2e/ios/MNISTDataHandler.mm +++ b/js/react_native/e2e/ios/MNISTDataHandler.mm @@ -2,10 +2,9 @@ // Licensed under the MIT License. #import "MNISTDataHandler.h" -#import "OnnxruntimeModule.h" -#import "TensorHelper.h" #import #import +#include NS_ASSUME_NONNULL_BEGIN @@ -119,7 +118,7 @@ - (NSDictionary*)preprocess:(NSString*)uri { inputTensorMap[@"dims"] = dims; // type - inputTensorMap[@"type"] = JsTensorTypeFloat; + inputTensorMap[@"type"] = @"float32"; // encoded data NSString* data = [byteBufferRef base64EncodedStringWithOptions:0]; diff --git a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj index 6f957af603385..70a5fcdd33cad 100644 --- a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj +++ b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj @@ -10,6 +10,13 @@ 13B07FBC1A68108700A75B9A /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 13B07FB01A68108700A75B9A /* AppDelegate.m */; }; 13B07FBF1A68108700A75B9A /* Images.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 13B07FB51A68108700A75B9A /* Images.xcassets */; }; 13B07FC11A68108700A75B9A /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 13B07FB71A68108700A75B9A /* main.m */; }; + 3ADD0A3C2EBB64D200761D6F /* ../src/test_types_int8.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A382EBB64D200761D6F /* ../src/test_types_int8.ort */; }; + 3ADD0A3D2EBB64D200761D6F /* ../src/test_types_int64.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A3A2EBB64D200761D6F /* ../src/test_types_int64.ort */; }; + 3ADD0A3E2EBB64D200761D6F /* ../src/test_types_int32.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A392EBB64D200761D6F /* ../src/test_types_int32.ort */; }; + 3ADD0A3F2EBB64D200761D6F /* ../src/test_types_float.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A372EBB64D200761D6F /* ../src/test_types_float.ort */; }; + 3ADD0A402EBB64D200761D6F /* ../src/test_types_uint8.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A3B2EBB64D200761D6F /* ../src/test_types_uint8.ort */; }; + 3ADD0A422EBB677300761D6F /* test_types_double.onnx in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A412EBB677300761D6F /* test_types_double.onnx */; }; + 3ADD0A442EBB679A00761D6F /* test_types_bool.onnx in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A432EBB679A00761D6F /* test_types_bool.onnx */; }; 81AB9BB82411601600AC10FF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 81AB9BB72411601600AC10FF /* LaunchScreen.storyboard */; }; DB61BA27278684FB0096C971 /* OnnxruntimeModuleExampleUITests.m in Sources */ = {isa = PBXBuildFile; fileRef = DB61BA26278684FB0096C971 /* OnnxruntimeModuleExampleUITests.m */; }; DBA8BA87267293C4008CC55A /* mnist.ort in Resources */ = {isa = PBXBuildFile; fileRef = 
DBA8BA86267293C4008CC55A /* mnist.ort */; }; @@ -50,6 +57,13 @@ 13B07FB51A68108700A75B9A /* Images.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; name = Images.xcassets; path = OnnxruntimeModuleExample/Images.xcassets; sourceTree = ""; }; 13B07FB61A68108700A75B9A /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; name = Info.plist; path = OnnxruntimeModuleExample/Info.plist; sourceTree = ""; }; 13B07FB71A68108700A75B9A /* main.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = main.m; path = OnnxruntimeModuleExample/main.m; sourceTree = ""; }; + 3ADD0A372EBB64D200761D6F /* ../src/test_types_float.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_float.ort; sourceTree = ""; }; + 3ADD0A382EBB64D200761D6F /* ../src/test_types_int8.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_int8.ort; sourceTree = ""; }; + 3ADD0A392EBB64D200761D6F /* ../src/test_types_int32.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_int32.ort; sourceTree = ""; }; + 3ADD0A3A2EBB64D200761D6F /* ../src/test_types_int64.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_int64.ort; sourceTree = ""; }; + 3ADD0A3B2EBB64D200761D6F /* ../src/test_types_uint8.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_uint8.ort; sourceTree = ""; }; + 3ADD0A412EBB677300761D6F /* test_types_double.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_double.onnx; sourceTree = ""; }; + 3ADD0A432EBB679A00761D6F /* test_types_bool.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_bool.onnx; sourceTree = ""; }; 81AB9BB72411601600AC10FF /* LaunchScreen.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; name = LaunchScreen.storyboard; path = OnnxruntimeModuleExample/LaunchScreen.storyboard; sourceTree = ""; }; 9D58C0FCCF00905433F4ED74 /* Pods-OnnxruntimeModuleExample.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleExample.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleExample/Pods-OnnxruntimeModuleExample.debug.xcconfig"; sourceTree = ""; }; B70FCE6DFAB320E9051DA321 /* Pods-OnnxruntimeModuleExample.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleExample.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleExample/Pods-OnnxruntimeModuleExample.release.xcconfig"; sourceTree = ""; }; @@ -128,6 +142,13 @@ 83CBB9F61A601CBA00E9B192 = { isa = PBXGroup; children = ( + 3ADD0A432EBB679A00761D6F /* test_types_bool.onnx */, + 3ADD0A412EBB677300761D6F /* test_types_double.onnx */, + 3ADD0A372EBB64D200761D6F /* ../src/test_types_float.ort */, + 3ADD0A382EBB64D200761D6F /* ../src/test_types_int8.ort */, + 3ADD0A392EBB64D200761D6F /* ../src/test_types_int32.ort */, + 3ADD0A3A2EBB64D200761D6F /* ../src/test_types_int64.ort */, + 3ADD0A3B2EBB64D200761D6F /* ../src/test_types_uint8.ort */, DBA8BA86267293C4008CC55A /* mnist.ort */, DBBF7413263B8CCB00487C77 /* 3.jpg */, 13B07FAE1A68108700A75B9A /* OnnxruntimeModuleExample */, @@ -247,6 +268,13 @@ DBA8BA87267293C4008CC55A /* mnist.ort in Resources */, DBBF7414263B8CCB00487C77 /* 3.jpg in Resources */, 
81AB9BB82411601600AC10FF /* LaunchScreen.storyboard in Resources */, + 3ADD0A422EBB677300761D6F /* test_types_double.onnx in Resources */, + 3ADD0A3C2EBB64D200761D6F /* ../src/test_types_int8.ort in Resources */, + 3ADD0A442EBB679A00761D6F /* test_types_bool.onnx in Resources */, + 3ADD0A3D2EBB64D200761D6F /* ../src/test_types_int64.ort in Resources */, + 3ADD0A3E2EBB64D200761D6F /* ../src/test_types_int32.ort in Resources */, + 3ADD0A3F2EBB64D200761D6F /* ../src/test_types_float.ort in Resources */, + 3ADD0A402EBB64D200761D6F /* ../src/test_types_uint8.ort in Resources */, E329E1162D3728940016B599 /* PrivacyInfo.xcprivacy in Resources */, 13B07FBF1A68108700A75B9A /* Images.xcassets in Resources */, ); diff --git a/js/react_native/e2e/ios/test_types_bool.ort b/js/react_native/e2e/ios/test_types_bool.ort new file mode 100644 index 0000000000000..ee955dcc6fe54 Binary files /dev/null and b/js/react_native/e2e/ios/test_types_bool.ort differ diff --git a/js/react_native/e2e/ios/test_types_double.ort b/js/react_native/e2e/ios/test_types_double.ort new file mode 100644 index 0000000000000..0259d0eae66ab Binary files /dev/null and b/js/react_native/e2e/ios/test_types_double.ort differ diff --git a/js/react_native/e2e/metro.config.js b/js/react_native/e2e/metro.config.js index 9f279f35616a3..e9ef3a02f075a 100644 --- a/js/react_native/e2e/metro.config.js +++ b/js/react_native/e2e/metro.config.js @@ -12,6 +12,7 @@ const config = { ], resolver: { sourceExts: ['tsx', 'ts', 'jsx', 'js', 'json'], // Ensure TypeScript files are recognized + assetExts: ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff', 'ico', 'webp', 'svg', 'ort', 'onnx'], }, }; module.exports = mergeConfig(getDefaultConfig(__dirname), config); diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index 54a78bf52a15e..931b86b3b08b1 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -2385,9 +2385,9 @@ } }, "node_modules/@eslint/eslintrc/node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { @@ -6639,9 +6639,9 @@ } }, "node_modules/eslint/node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { @@ -9028,7 +9028,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "3.14.1", + "version": "3.14.2", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", + "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", "license": "MIT", "dependencies": { "argparse": "^1.0.7", diff --git a/js/react_native/e2e/src/App.tsx b/js/react_native/e2e/src/App.tsx index 39c496062f665..438073f864c42 100644 --- a/js/react_native/e2e/src/App.tsx +++ 
b/js/react_native/e2e/src/App.tsx @@ -2,22 +2,65 @@ // Licensed under the MIT License. import * as React from 'react'; -import { Image, Text, TextInput, View } from 'react-native'; -// onnxruntime-react-native package is installed when bootstraping -import { InferenceSession, Tensor } from 'onnxruntime-react-native'; -import MNIST, { MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler'; -import { Buffer } from 'buffer'; -import { readFile } from 'react-native-fs'; +import { Button, SafeAreaView, ScrollView, StyleSheet, Text, View } from 'react-native'; +import MNISTTest from './MNISTTest'; +import BasicTypesTest from './BasicTypesTest'; + +type Page = 'home' | 'mnist' | 'basic-types'; interface State { - session: - InferenceSession | null; - output: - string | null; - imagePath: - string | null; + currentPage: Page; } +const styles = StyleSheet.create({ + container: { + flex: 1, + backgroundColor: '#f5f5f5', + }, + scrollContent: { + padding: 20, + alignItems: 'center', + }, + title: { + fontSize: 28, + fontWeight: 'bold', + marginTop: 20, + marginBottom: 10, + color: '#333', + textAlign: 'center', + }, + subtitle: { + fontSize: 18, + marginBottom: 30, + color: '#666', + textAlign: 'center', + }, + buttonContainer: { + width: '100%', + marginBottom: 30, + alignItems: 'center', + }, + buttonWrapper: { + width: '80%', + marginBottom: 10, + }, + description: { + fontSize: 14, + color: '#888', + textAlign: 'center', + paddingHorizontal: 20, + }, + header: { + padding: 10, + backgroundColor: '#fff', + borderBottomWidth: 1, + borderBottomColor: '#ddd', + }, + testContent: { + flex: 1, + }, +}); + // eslint-disable-next-line @typescript-eslint/no-empty-object-type export default class App extends React.PureComponent<{}, State> { // eslint-disable-next-line @typescript-eslint/no-empty-object-type @@ -25,104 +68,89 @@ export default class App extends React.PureComponent<{}, State> { super(props); this.state = { - session: null, - output: null, - imagePath: null, + currentPage: 'home', }; } - // Load a model when an app is loading - async componentDidMount(): Promise { - if (!this.state.session) { - try { - const imagePath = await MNIST.getImagePath(); - this.setState({ imagePath }); - - const modelPath = await MNIST.getLocalModelPath(); - - // test creating session with path - console.log('Creating with path'); - const pathSession: InferenceSession = await InferenceSession.create(modelPath); - void pathSession.release(); - - // and with bytes - console.log('Creating with bytes'); - const base64Str = await readFile(modelPath, 'base64'); - const bytes = Buffer.from(base64Str, 'base64'); - const session: InferenceSession = await InferenceSession.create(bytes); - this.setState({ session }); - - console.log('Test session created'); - void await this.infer(); - } catch (err) { - console.log(err.message); - } - } + navigateTo = (page: Page) => { + this.setState({ currentPage: page }); + }; + + renderHome(): React.JSX.Element { + return ( + + + ONNX Runtime E2E Tests + Select a test to run: + + + +