From 1f0067968f8c82affef8ead63fb16118da16e39a Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:02:28 -0800 Subject: [PATCH 01/15] Revert to use setup-dotnet@v3 for x86 release pipeline (#26559) ### Description Revert to use setup-dotnet@v3. Using actions/setup-dotnet@v5 is having issue as it keeps using latest dotnet 10.0.0 that makes pipeline failed. ### Motivation and Context --- .github/workflows/windows_x86.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index d81c5d559c8e5..d20778d56f60b 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -61,7 +61,7 @@ jobs: working-directory: ${{ github.workspace }} - name: Use .NET 8.x - uses: actions/setup-dotnet@v5 + uses: actions/setup-dotnet@v3 with: dotnet-version: '8.x' env: From a7df7b19a0c80695909e8e845d9c913d6fef53dd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 15 Nov 2025 00:07:35 +0000 Subject: [PATCH 02/15] Bump js-yaml from 4.1.0 to 4.1.1 in /js (#26577) Bumps [js-yaml](https://github.com/nodeca/js-yaml) from 4.1.0 to 4.1.1.
Changelog

Sourced from js-yaml's changelog.

[4.1.1] - 2025-11-12

Security

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=js-yaml&package-manager=npm_and_yarn&previous-version=4.1.0&new-version=4.1.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/package-lock.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/js/package-lock.json b/js/package-lock.json index a13f1ae373f4b..1e9f5cb29fe6c 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4020,9 +4020,9 @@ } }, "node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "dependencies": { "argparse": "^2.0.1" @@ -8555,9 +8555,9 @@ } }, "js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "requires": { "argparse": "^2.0.1" From d6a372aaf7233554aff09c16243eb2becc7b5ed9 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 14 Nov 2025 16:17:50 -0800 Subject: [PATCH 03/15] qgemm: optimize avxvnni QGEMM inner kernel for M=1 (#22952) Add specialized path for M=1 case that exploits additional available ymm registers for deeper inner kernel loop unrolling. 
Performance impact (measured on 13th Gen Intel(R) Core(TM) i9-13900K): - 30% improvement in single threaded QGEMM kernels with M = 1 - 7% reduction in average inference time on small quantized model where all kernels have M=1 ``` |--------------------------------------------------------------------+--------+---------+----------+----------+---------+---------| | Benchmark | Time | CPU | Time Old | Time New | CPU Old | CPU New | |--------------------------------------------------------------------+--------+---------+----------+----------+---------+---------| | QGEMM/UnsignedAPackB/M:1/N:512/K:512/Batch:1/Threads:1/real_time | -0.275 | -0.2756 | 4330 | 3137 | 4330 | 3136 | | QGEMM/UnsignedAPackB/M:1/N:512/K:1024/Batch:1/Threads:1/real_time | -0.292 | -0.2927 | 9027 | 6385 | 9027 | 6385 | | QGEMM/UnsignedAPackB/M:1/N:1024/K:1024/Batch:1/Threads:1/real_time | -0.300 | -0.3005 | 17867 | 12499 | 17866 | 12498 | | OVERALL_GEOMEAN | -0.289 | -0.2897 | | | | | |--------------------------------------------------------------------+--------+---------+----------+----------+---------+---------| ``` --------- Co-authored-by: Raghuveer Devulapalli --- .../mlas/lib/amd64/QgemmU8X8KernelAvx2.asm | 308 ++++++++++++++--- onnxruntime/core/mlas/lib/amd64/mlasi.inc | 9 + .../mlas/lib/x86_64/QgemmU8X8KernelAvx2.S | 315 ++++++++++++++---- onnxruntime/core/mlas/lib/x86_64/asmmacro.h | 22 ++ onnxruntime/test/mlas/bench/bench_qgemm.cpp | 3 + 5 files changed, 558 insertions(+), 99 deletions(-) diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm index 1705a15fa4dc7..e65e43d93e671 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm @@ -41,7 +41,7 @@ GemmInt8KernelFrame STRUCT SavedXmm13 OWORD ? SavedXmm14 OWORD ? SavedXmm15 OWORD ? - Padding QWORD ? + SavedR14 QWORD ? SavedR13 QWORD ? SavedR12 QWORD ? SavedRdi QWORD ? 
@@ -165,6 +165,42 @@ ENDIF ENDM +; Macro Description: +; +; This macro generates the appropriate vpdp instruction based on the ASigned +; and BSigned values. +; +; Arguments: +; +; ASigned - sign of A. +; +; BSigned - sign of B. +; +; reg1 - Output register for vpdp instruction +; +; reg2 - Second input register for vpdp instruction +; +; reg3 - First input register for vpdp instruction +; + +VpdpYmmYmmYmm MACRO ASigned, BSigned, reg1, reg2, reg3 + + IF ASigned EQ 1 + IF BSigned EQ 1 + VpdpbssdYmmYmmYmm reg1, reg2, reg3 + ELSE + VpdpbsudYmmYmmYmm reg1, reg2, reg3 + ENDIF + ELSE + IF BSigned EQ 1 + VpdpbusdYmmYmmYmm reg1, reg2, reg3 + ELSE + VpdpbuudYmmYmmYmm reg1, reg2, reg3 + ENDIF + ENDIF + + ENDM + ; ; Macro Description: ; @@ -190,41 +226,21 @@ ENDIF ; ymm2 - Supplies the broadcast value loaded from matrix A. ; -MultiplyAccumulateRowAvxVnni MACRO ColumnCount, Vec1Reg, Vec2Reg, ASigned, BSigned +MultiplyAccumulateRowAvxVnni MACRO ColumnCount, ASigned, BSigned, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg -IF ASigned EQ 1 - IF BSigned EQ 1 - IF ColumnCount EQ 16 - VpdpbssdYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbssdYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbssdYmmYmmYmm Vec2Reg,ymm2,ymm0 + IF ColumnCount EQ 32 + VpdpYmmYmmYmm ASigned, BSigned, Vec1Reg, ymm2, ymm0 + VpdpYmmYmmYmm ASigned, BSigned, Vec2Reg, ymm2, ymm1 + VpdpYmmYmmYmm ASigned, BSigned, Vec3Reg, ymm2, ymm14 + VpdpYmmYmmYmm ASigned, BSigned, Vec4Reg, ymm2, ymm15 ENDIF - ELSE IF ColumnCount EQ 16 - VpdpbsudYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbsudYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbsudYmmYmmYmm Vec2Reg,ymm2,ymm0 + VpdpYmmYmmYmm ASigned, BSigned, Vec1Reg, ymm2, ymm0 + VpdpYmmYmmYmm ASigned, BSigned, Vec2Reg, ymm2, ymm1 ENDIF - ENDIF -ELSE - IF BSigned EQ 1 - IF ColumnCount EQ 16 - VpdpbusdYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbusdYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbusdYmmYmmYmm Vec2Reg,ymm2,ymm0 + IF ColumnCount EQ 8 + VpdpYmmYmmYmm ASigned, BSigned, Vec2Reg, ymm2, ymm0 ENDIF - ELSE - IF ColumnCount EQ 16 - 
VpdpbuudYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbuudYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbuudYmmYmmYmm Vec2Reg,ymm2,ymm0 - ENDIF - ENDIF -ENDIF ENDM @@ -261,18 +277,20 @@ ComputeBlockAvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset, vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] EmitIfCountGE ColumnCount, 16, + EmitIfCount2EQ ColumnCount, 32, RowCount, 1, + EmitIfCount2EQ ColumnCount, 32, RowCount, 1, EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, + EmitIfCountGE RowCount, 5, EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, ENDM @@ -312,7 +330,8 @@ ComputeBlockLoop MACRO Isa, ColumnCount, RowCount, ASigned, BSigned mov rsi,r9 ; reload row length remaining -IF (ColumnCount EQ 16) AND (RowCount EQ 1) +IF (ColumnCount EQ 16) OR (ColumnCount EQ 32) +IF (RowCount EQ 1) sub rsi,4*4 jb ProcessRemainingBlocks @@ -329,7 +348,8 @@ ComputeBlockBy4Loop: ProcessRemainingBlocks: add rsi,4*4 ; correct for over-subtract above jz ComputeBlockLoopExit -ENDIF +ENDIF ; RowCount == 1 +ENDIF ; ColumnCount == 16/32 ComputeBlockBy1Loop: ComputeBlock&Isa& ColumnCount, RowCount, 0, 0, ASigned, BSigned @@ -552,24 +572,44 @@ ProduceOutputBlock MACRO ColumnCount, RowCount, ASigned, BSigned EmitIfCountGE RowCount, 4, EmitIfCountGE RowCount, 5, EmitIfCountGE RowCount, 6, +IF ColumnCount EQ 32 + vmovdqu ymm0,YMMWORD PTR [r12] + vmovdqu ymm1,YMMWORD PTR [r12+32] + vmovdqu ymm14,YMMWORD PTR [r12+64] + vmovdqu ymm15,YMMWORD PTR [r12+96] + add r12,32*4 ; advance ColumnSumBuffer by 32 columns +ENDIF IF ColumnCount EQ 16 vmovdqu ymm0,YMMWORD PTR [r12] vmovdqu ymm1,YMMWORD PTR [r12+32] add r12,16*4 ; advance 
ColumnSumBuffer by 16 columns -ELSE +ENDIF +IF ColumnCount EQ 8 vmovdqu ymm1,YMMWORD PTR [r12] ENDIF test r13,r13 ; per column zero points? jz SkipScaleByZeroPointB +IF ColumnCount EQ 32 + vmovdqu ymm2,YMMWORD PTR [r13] + vmovdqu ymm3,YMMWORD PTR [r13+32] + vmovdqu ymm12,YMMWORD PTR [r13+64] + vmovdqu ymm13,YMMWORD PTR [r13+96] + add r13,32*4 ; advance ZeroPointB by 32 columns +ENDIF IF ColumnCount EQ 16 vmovdqu ymm2,YMMWORD PTR [r13] vmovdqu ymm3,YMMWORD PTR [r13+32] add r13,16*4 ; advance ZeroPointB by 16 columns -ELSE +ENDIF +IF ColumnCount EQ 8 vmovdqu ymm3,YMMWORD PTR [r13] ENDIF + EmitIfCount2EQ RowCount, 1, ColumnCount, 32, + EmitIfCount2EQ RowCount, 1, ColumnCount, 32, EmitIfCount2GE RowCount, 1, ColumnCount, 16, EmitIfCountGE RowCount, 1, + EmitIfCount2EQ RowCount, 1, ColumnCount, 32, + EmitIfCount2EQ RowCount, 1, ColumnCount, 32, EmitIfCount2GE RowCount, 1, ColumnCount, 16, EmitIfCountGE RowCount, 1, EmitIfCount2GE RowCount, 2, ColumnCount, 16, @@ -595,6 +635,8 @@ ENDIF jmp AccumulatorsInitialized SkipScaleByZeroPointB: + EmitIfCount2EQ RowCount, 1, ColumnCount, 32, + EmitIfCount2EQ RowCount, 1, ColumnCount, 32, EmitIfCount2GE RowCount, 1, ColumnCount, 16, EmitIfCountGE RowCount, 1, EmitIfCount2GE RowCount, 2, ColumnCount, 16, @@ -810,6 +852,177 @@ SkipAccumulateOutputMasked8xNBlock: ENDM +; +; Section Description: +; +; This macro generates code to compute matrix multiplication for a single +; row. When processing just one row, there are more ymm registers available +; for us to unroll the main kernel further to benefit from better pipelining +; the dot product instruction. 
+; +; Arguments: None +; +; Implicit Arguments: Same as ProcessCountM +; +; + +ProcessCount1AvxVnni MACRO RowCount, ASigned, BSigned, Fallthrough + + LOCAL LProcessNextColumnLoop32xN1 + LOCAL LSkipAccumulateOutput32xNBlock1 + LOCAL LProcessNextColumnLoop16xN1 + LOCAL LSkipAccumulateOutput16xNBlock1 + LOCAL LProcessRemainingCountN1 + LOCAL LSkipAccumulateOutput8xNBlock1 + LOCAL LExitProcessCountM1 + LOCAL LOutputMasked32xNBlock1 + LOCAL LSkipAccumulateOutputMasked32xNBlock1 + LOCAL LOutputMasked24xNBlock1 + LOCAL LSkipAccumulateOutputMasked24xNBlock1 + LOCAL LOutputMasked16xNBlock1 + LOCAL LSkipAccumulateOutputMasked16xNBlock1 + LOCAL LOutputMasked8xNBlock1 + LOCAL LSkipAccumulateOutputMasked8xNBlock1 + + cmp rbp,8 + jbe LProcessRemainingCountN1 ; num of cols <= 8?: process the tail + cmp rbp,16 + jbe LProcessNextColumnLoop16xN1 ; num of cols <= 16?: process 16 at a time: + +LProcessNextColumnLoop32xN1: ; Output loop to process 32 cols at a time: + ProduceOutputBlock 32, 1, ASigned, BSigned + add rdx,r14 + sub rbp,32 + jb LOutputMasked32xNBlock1 ; if numcols < 32 (& > 16), use write using masked output and exit + test r10b,r10b ; ZeroMode? 
+ jnz LSkipAccumulateOutput32xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + vpaddd ymm6,ymm6,YMMWORD PTR [r8+64] + vpaddd ymm7,ymm7,YMMWORD PTR [r8+96] + +LSkipAccumulateOutput32xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + vmovdqu YMMWORD PTR [r8+64],ymm6 + vmovdqu YMMWORD PTR [r8+96],ymm7 + add r8,32*4 ; advance matrix C by 32 columns + mov rcx,rdi ; reload matrix A + cmp rbp,0 + je LExitProcessCountM1 + cmp rbp,8 + jle LProcessRemainingCountN1 ; num of cols < 8 + cmp rbp,16 + ja LProcessNextColumnLoop32xN1 ; num of cols > 16?: process 32 at a time: + +LProcessNextColumnLoop16xN1: ; num of cols > 8 and <= 16 + ProduceOutputBlock 16, 1, ASigned, BSigned + sub rbp,16 + jb LOutputMasked16xNBlock1 ; if numcols < 16 (& > 8), use write using masked output and exit + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutput16xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + +LSkipAccumulateOutput16xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,0 + je LExitProcessCountM1 + cmp rbp,8 + ja LProcessNextColumnLoop16xN1 ; num of cols > 8?: process 16 at a time: + +; Loop if num of cols <= 8 +LProcessRemainingCountN1: + ProduceOutputBlock 8, 1, ASigned, BSigned + cmp rbp,8 + jb LOutputMasked8xNBlock1 ; if numcols < 8, use write using masked output and exit + test r10b,r10b ; ZeroMode? 
+ jnz LSkipAccumulateOutput8xNBlock1 + vpaddd ymm5,ymm5,YMMWORD PTR [r8] + +LSkipAccumulateOutput8xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm5 + +LExitProcessCountM1: ; num of cols = 0, we are done + mov eax, 1 + jmp ExitKernel + +;; -- Section to write final tail of C matrix and exit -- ;; +;; write <= 32 elements ;; +LOutputMasked32xNBlock1: + add rbp,32 + cmp rbp,24 + jle LOutputMasked24xNBlock1 + sub rbp,24 + neg rbp + lea rcx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutputMasked32xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + vpaddd ymm6,ymm6,YMMWORD PTR [r8+64] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+96] + vpaddd ymm7,ymm7,ymm8 + +; First write 16 cols using regular mov and then maskmov for the rest < 8 cols +LSkipAccumulateOutputMasked32xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + vmovdqu YMMWORD PTR [r8+64],ymm6 + vpmaskmovd YMMWORD PTR [r8+96],ymm0,ymm7 + jmp LExitProcessCountM1 + +;; write <= 24 elements ;; +LOutputMasked24xNBlock1: + sub rbp,16 + neg rbp + lea rcx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutputMasked24xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+64] + vpaddd ymm6,ymm6,ymm8 + +; First write 16 cols using regular mov and then maskmov for the rest < 8 cols +LSkipAccumulateOutputMasked24xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + vpmaskmovd YMMWORD PTR [r8+64],ymm0,ymm6 + jmp LExitProcessCountM1 + +;; write <= 16 elements ;; +LOutputMasked16xNBlock1: + add rbp,16 + test r10b,r10b ; ZeroMode? 
+ jnz LSkipAccumulateOutputMasked16xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + +LSkipAccumulateOutputMasked16xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + add r8,8*4 ; advance matrix C by 8 columns + sub rbp,8 + +; at this point, rbp should be the value of num elements left to write +LOutputMasked8xNBlock1: + neg rbp + lea rcx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutputMasked8xNBlock1 + vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,ymm4 + +LSkipAccumulateOutputMasked8xNBlock1: + vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5 + jmp LExitProcessCountM1 + + ENDM ;++ ; @@ -870,7 +1083,8 @@ MlasGemmInt8KernelAvx2 MACRO ASigned, BSigned push_reg rdi push_reg r12 push_reg r13 - alloc_stack (GemmInt8KernelFrame.SavedR13) + push_reg r14 + alloc_stack (GemmInt8KernelFrame.SavedR14) save_xmm128 xmm6,GemmInt8KernelFrame.SavedXmm6 save_xmm128 xmm7,GemmInt8KernelFrame.SavedXmm7 save_xmm128 xmm8,GemmInt8KernelFrame.SavedXmm8 @@ -897,6 +1111,8 @@ MlasGemmInt8KernelAvx2 MACRO ASigned, BSigned mov r13,GemmInt8KernelFrame.ZeroPointB[rsp] vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF] vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001] + lea r14,[r9*8] + lea r14,[r14*2] cmp DWORD PTR GemmInt8KernelFrame.PreviousP1Home[rsp],0 je CheckCountM4OrMore ; U8S8 AVX2 kernel requires extra registers @@ -941,10 +1157,11 @@ ExitKernel: movaps xmm13,GemmInt8KernelFrame.SavedXmm13[rsp] movaps xmm14,GemmInt8KernelFrame.SavedXmm14[rsp] movaps xmm15,GemmInt8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmInt8KernelFrame.SavedR13) + add rsp,(GemmInt8KernelFrame.SavedR14) BEGIN_EPILOGUE + pop r14 pop r13 pop r12 pop rdi @@ -954,8 +1171,13 @@ ExitKernel: ret ProcessCountM1: + cmp DWORD PTR GemmInt8KernelFrame.PreviousP1Home[rsp],-1 + je ProcessCountM1AvxVnni ProcessCountM 1, ASigned, BSigned +ProcessCountM1AvxVnni: + ProcessCount1AvxVnni 1, ASigned, BSigned + ProcessCountM3: ProcessCountM 3, ASigned, 
BSigned diff --git a/onnxruntime/core/mlas/lib/amd64/mlasi.inc b/onnxruntime/core/mlas/lib/amd64/mlasi.inc index 2db3147168727..a4f58c1060a8a 100644 --- a/onnxruntime/core/mlas/lib/amd64/mlasi.inc +++ b/onnxruntime/core/mlas/lib/amd64/mlasi.inc @@ -93,6 +93,15 @@ ENDIF ENDM +EmitIfCount2EQ MACRO Count1, Value1, Count2, Value2, Statement + +IF (Count1 EQ Value1) AND (Count2 EQ Value2) + Statement +ENDIF + + ENDM + + ; ; Macro Description: ; diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S index af2a475ea0c59..9199d1ead475b 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S @@ -28,16 +28,17 @@ Abstract: // .equ .LGemmInt8KernelFrame_type, -8 - .equ .LGemmInt8KernelFrame_SavedR13, 0 - .equ .LGemmInt8KernelFrame_SavedR12, 8 - .equ .LGemmInt8KernelFrame_SavedRbx, 16 - .equ .LGemmInt8KernelFrame_SavedRbp, 24 - .equ .LGemmInt8KernelFrame_ReturnAddress, 32 - .equ .LGemmInt8KernelFrame_ldc, 40 - .equ .LGemmInt8KernelFrame_RowSumBuffer, 48 - .equ .LGemmInt8KernelFrame_ColumnSumBuffer, 56 - .equ .LGemmInt8KernelFrame_ZeroPointB, 64 - .equ .LGemmInt8KernelFrame_ZeroMode, 72 + .equ .LGemmInt8KernelFrame_SavedR14, 0 + .equ .LGemmInt8KernelFrame_SavedR13, 8 + .equ .LGemmInt8KernelFrame_SavedR12, 16 + .equ .LGemmInt8KernelFrame_SavedRbx, 24 + .equ .LGemmInt8KernelFrame_SavedRbp, 32 + .equ .LGemmInt8KernelFrame_ReturnAddress, 40 + .equ .LGemmInt8KernelFrame_ldc, 48 + .equ .LGemmInt8KernelFrame_RowSumBuffer, 56 + .equ .LGemmInt8KernelFrame_ColumnSumBuffer, 64 + .equ .LGemmInt8KernelFrame_ZeroPointB, 72 + .equ .LGemmInt8KernelFrame_ZeroMode, 80 /*++ @@ -145,6 +146,44 @@ Implicit Arguments: .endm +/*++ +Macro Description: + + This macro generates the appropriate vpdp instruction based on the ASigned + and BSigned values. + +Arguments: + + ASigned - sign of A. + + BSigned - sign of B. 
+ + reg1 - Output register for vpdp instruction + + reg2 - Second input register for vpdp instruction + + reg3 - First input register for vpdp instruction + +--*/ + + .macro VpdpYmmYmmYmm ASigned, BSigned, reg1, reg2, reg3 + + .if \ASigned\() == 1 + .if \BSigned\() == 1 + VpdpbssdYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .else + VpdpbsudYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .endif + .else + .if \BSigned\() == 1 + VpdpbusdYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .else + VpdpbuudYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .endif + .endif + + .endm + /*++ Macro Description: @@ -171,41 +210,21 @@ Implicit Arguments: --*/ - .macro MultiplyAccumulateRowAvxVnni ColumnCount, Vec1Reg, Vec2Reg, ASigned, BSigned - -.if \ASigned\() == 1 - .if \BSigned\() == 1 - .if \ColumnCount\() == 16 - VpdpbssdYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbssdYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbssdYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .else - .if \ColumnCount\() == 16 - VpdpbsudYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbsudYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbsudYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .endif -.else - .if \BSigned\() == 1 - .if \ColumnCount\() == 16 - VpdpbusdYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbusdYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbusdYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .else - .if \ColumnCount\() == 16 - VpdpbuudYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbuudYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbuudYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .endif -.endif + .macro MultiplyAccumulateRowAvxVnni ColumnCount, ASigned, BSigned, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + .if \ColumnCount\() == 32 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec1Reg\(), ymm2, ymm0 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec2Reg\(), ymm2, ymm1 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec3Reg\(), ymm2, ymm14 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec4Reg\(), ymm2, ymm15 + .endif + .if \ColumnCount\() == 16 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec1Reg\(), 
ymm2, ymm0 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec2Reg\(), ymm2, ymm1 + .endif + .if \ColumnCount\() == 8 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec2Reg\(), ymm2, ymm0 + .endif .endm @@ -244,18 +263,20 @@ Implicit Arguments: vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" + EmitIfCount2EQ \ColumnCount\(), 32, \RowCount\(), 1, "vmovdqu ymm14,YMMWORD PTR [rsi+r14+\VectorOffset\()]" + EmitIfCount2EQ \ColumnCount\(), 32, \RowCount\(), 1, "vmovdqu ymm15,YMMWORD PTR [rsi+r14+\VectorOffset\()+32]" EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm4, ymm5, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm4, ymm5, ymm6, ymm7" EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm6, ymm7, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm6, ymm7" EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm8, ymm9, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm8, ymm9" EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm10, ymm11, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm10, ymm11" EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR 
[r8+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm12, ymm13, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm12, ymm13" EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm14, ymm15, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm14, ymm15" .endm @@ -292,7 +313,7 @@ Implicit Arguments: mov rbp,rcx # reload row length remaining -.if (\ColumnCount\() == 16) && (\RowCount\() == 1) +.if (\ColumnCount\() >= 16) && (\RowCount\() == 1) sub rbp,4*4 jb .LProcessRemainingBlocks\@ @@ -527,24 +548,42 @@ Implicit Arguments: EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r11+12]" EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r11+16]" EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r11+20]" -.if \ColumnCount\() == 16 +.if \ColumnCount\() >= 16 +.if \ColumnCount\() == 32 vmovdqu ymm0,YMMWORD PTR [r12] vmovdqu ymm1,YMMWORD PTR [r12+32] - add r12,16*4 # advance ColumnSumBuffer by 16 columns + vmovdqu ymm14,YMMWORD PTR [r12+64] + vmovdqu ymm15,YMMWORD PTR [r12+96] +.else + vmovdqu ymm0,YMMWORD PTR [r12] + vmovdqu ymm1,YMMWORD PTR [r12+32] +.endif + add r12,\ColumnCount\()*4 # advance ColumnSumBuffer by 16/32 columns .else vmovdqu ymm1,YMMWORD PTR [r12] .endif test r13,r13 # per column zero points? 
jz .LSkipScaleByZeroPointB\@ -.if \ColumnCount\() == 16 +.if \ColumnCount\() >= 16 +.if \ColumnCount\() == 32 vmovdqu ymm2,YMMWORD PTR [r13] vmovdqu ymm3,YMMWORD PTR [r13+32] - add r13,16*4 # advance ZeroPointB by 16 columns + vmovdqu ymm12,YMMWORD PTR [r13+64] + vmovdqu ymm13,YMMWORD PTR [r13+96] +.else + vmovdqu ymm2,YMMWORD PTR [r13] + vmovdqu ymm3,YMMWORD PTR [r13+32] +.endif + add r13,\ColumnCount\()*4 # advance ZeroPointB by 16/32 columns .else vmovdqu ymm3,YMMWORD PTR [r13] .endif + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld ymm6,ymm5,ymm12" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld ymm7,ymm5,ymm13" EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld ymm4,ymm5,ymm2" EmitIfCountGE \RowCount\(), 1, "vpmulld ymm5,ymm5,ymm3" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm6,ymm14,ymm6" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm7,ymm15,ymm7" EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm0,ymm4" EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm1,ymm5" EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld ymm6,ymm7,ymm2" @@ -570,6 +609,8 @@ Implicit Arguments: jmp .LAccumulatorsInitialized\@ .LSkipScaleByZeroPointB\@: + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm6,ymm5,ymm14" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm7,ymm5,ymm15" EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" @@ -777,6 +818,159 @@ Implicit Arguments: /*++ +Section Description: + This macro generates code to compute matrix multiplication for a single + row. When processing just one row, there are more ymm registers available + for us to unroll the main kernel further to benefit from better pipelining + the dot product instruction. 
+Arguments: None +Implicit Arguments: Same as ProcessCountM + +--*/ + + .macro ProcessCount1AvxVnni ASigned, BSigned + cmp r9,8 + jbe .LProcessRemainingCountN1\@ # num of cols <= 8?: process the tail + cmp r9,16 + jbe .LProcessNextColumnLoop16xN1\@ # num of cols <= 16?: process 16 at a time: + +.LProcessNextColumnLoop32xN1\@: # Output loop to process 32 cols at a time: + ProduceOutputBlock 32, 1, \ASigned\(), \BSigned\() + add rsi,r14 + sub r9,32 + jb .LOutputMasked32xNBlock1\@ # if numcols < 32 (& > 16), use write using masked output and exit + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput32xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32] + vpaddd ymm6,ymm6,YMMWORD PTR [rdx+64] + vpaddd ymm7,ymm7,YMMWORD PTR [rdx+96] + +.LSkipAccumulateOutput32xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + vmovdqu YMMWORD PTR [rdx+32],ymm5 + vmovdqu YMMWORD PTR [rdx+64],ymm6 + vmovdqu YMMWORD PTR [rdx+96],ymm7 + add rdx,32*4 # advance matrix C by 32 columns + mov rdi,rbx # reload matrix A + cmp r9,0 + je .LExitProcessCountM1\@ + cmp r9,8 + jle .LProcessRemainingCountN1\@ # num of cols < 8 + cmp r9,16 + ja .LProcessNextColumnLoop32xN1\@ # num of cols > 16?: process 32 at a time: + +.LProcessNextColumnLoop16xN1\@: # num of cols > 8 and <= 16 + ProduceOutputBlock 16, 1, \ASigned\(), \BSigned\() + sub r9,16 + jb .LOutputMasked16xNBlock1\@ # if numcols < 16 (& > 8), use write using masked output and exit + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutput16xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32] + +.LSkipAccumulateOutput16xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + vmovdqu YMMWORD PTR [rdx+32],ymm5 + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,rbx # reload matrix A + cmp r9,0 + je .LExitProcessCountM1\@ + cmp r9,8 + ja .LProcessNextColumnLoop16xN1\@ # num of cols > 8?: process 16 at a time: + +# Loop if num of cols <= 8 +.LProcessRemainingCountN1\@: + ProduceOutputBlock 8, 1, \ASigned\(), \BSigned\() + cmp r9,8 + jb .LOutputMasked8xNBlock1\@ # if numcols < 8, use write using masked output and exit + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput8xNBlock1\@ + vpaddd ymm5,ymm5,YMMWORD PTR [rdx] + +.LSkipAccumulateOutput8xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm5 + +.LExitProcessCountM1\@: # num of cols = 0, we are done + mov eax, 1 + jmp .LExitKernel + +## -- Section to write final tail of C matrix and exit -- ## +## write <= 32 elements ## +.LOutputMasked32xNBlock1\@: + add r9,32 + cmp r9,24 + jle .LOutputMasked24xNBlock1\@ + sub r9,24 + neg r9 + lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutputMasked32xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32] + vpaddd ymm6,ymm6,YMMWORD PTR [rdx+64] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+96] + vpaddd ymm7,ymm7,ymm8 + +# First write 16 cols using regular mov and then maskmov for the rest < 8 cols +.LSkipAccumulateOutputMasked32xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + vmovdqu YMMWORD PTR [rdx+32],ymm5 + vmovdqu YMMWORD PTR [rdx+64],ymm6 + vpmaskmovd YMMWORD PTR [rdx+96],ymm0,ymm7 + jmp .LExitProcessCountM1\@ + +## write <= 24 elements ## +.LOutputMasked24xNBlock1\@: + sub r9,16 + neg r9 + lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutputMasked24xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+64] + vpaddd ymm6,ymm6,ymm8 + +# First write 16 cols using regular mov and then maskmov for the rest < 8 cols +.LSkipAccumulateOutputMasked24xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + vmovdqu YMMWORD PTR [rdx+32],ymm5 + vpmaskmovd YMMWORD PTR [rdx+64],ymm0,ymm6 + jmp .LExitProcessCountM1\@ + +## write <= 16 elements ## +.LOutputMasked16xNBlock1\@: + add r9,16 + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutputMasked16xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + +.LSkipAccumulateOutputMasked16xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + add rdx,8*4 # advance matrix C by 8 columns + sub r9,8 + +# at this point, r9 should be the value of num elements left to write +.LOutputMasked8xNBlock1\@: + neg r9 + lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutputMasked8xNBlock1\@ + vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,ymm4 + +.LSkipAccumulateOutputMasked8xNBlock1\@: + vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5 + jmp .LExitProcessCountM1\@ + + .endm + +/*++ + Routine Description: This routine is an inner kernel to compute matrix multiplication for a @@ -832,6 +1026,7 @@ Return Value: push rbx push r12 push r13 + push r14 mov DWORD PTR .LGemmInt8KernelFrame_type[rsp],eax mov rbx,rdi @@ -844,6 +1039,8 @@ Return Value: mov r13,.LGemmInt8KernelFrame_ZeroPointB[rsp] vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF] vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001] + lea rbp,[rcx*8] + lea r14,[rbp*2] cmp DWORD PTR .LGemmInt8KernelFrame_type[rsp],0 je .LCheckCountM4OrMore\@ # U8S8 AVX2 kernel requires extra registers @@ -873,8 +1070,13 @@ Return Value: ProcessCountM 6, \ASigned\(), \BSigned\() .LProcessCountM1\@: + cmp DWORD PTR .LGemmInt8KernelFrame_type[rsp],-1 + je .LProcessCountM1AvxVnni\@ ProcessCountM 1, \ASigned\(), \BSigned\() +.LProcessCountM1AvxVnni\@: + ProcessCount1AvxVnni \ASigned\(), \BSigned\() + .LProcessCountM3\@: ProcessCountM 3, \ASigned\(), \BSigned\() @@ -890,6 +1092,7 @@ Return Value: .LExitKernel: vzeroupper + pop r14 pop r13 pop r12 pop rbx diff --git a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h b/onnxruntime/core/mlas/lib/x86_64/asmmacro.h index 7d7b3079a5132..7ef836c5701f3 100644 --- a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h +++ b/onnxruntime/core/mlas/lib/x86_64/asmmacro.h @@ -97,6 +97,28 @@ Macro Description: .endm + +/*++ +Macro Description: + This macro conditionally emits the statement if Count1 is equal to Value1 + and Count2 is equal to Value2. +Arguments: + Count1 - Supplies the variable used in the comparison. + Value1 - Supplies the static used in the comparison. + Count2 - Supplies the variable used in the comparison. + Value2 - Supplies the static used in the comparison. 
+ Statement - Supplies the statement to conditionally emit. +--*/ + + .macro EmitIfCount2EQ Count1, Value1, Count2, Value2, Statement + +.if (\Count1\() == \Value1\()) && (\Count2\() == \Value2\()) + \Statement\() +.endif + + .endm + + /*++ Macro Description: diff --git a/onnxruntime/test/mlas/bench/bench_qgemm.cpp b/onnxruntime/test/mlas/bench/bench_qgemm.cpp index 29a68f6aec6e6..42d1a590c71b9 100644 --- a/onnxruntime/test/mlas/bench/bench_qgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qgemm.cpp @@ -82,6 +82,9 @@ static void QGemmSize(benchmark::internal::Benchmark* b) { b->ArgNames(qgemm_arg_names); // Args for "M", "N", "K", "Batch", "Threads" + b->Args({1, 512, 512, 1, 1}); + b->Args({1, 512, 1024, 1, 1}); + b->Args({1, 1024, 1024, 1, 1}); b->Args({384, 1024, 1024, 1, 4}); b->Args({384, 1024, 3072, 1, 4}); b->Args({384, 1024, 4096, 1, 4}); From 683e78c3e4e11e4c14347a983f9cd3eba1a06c98 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 14 Nov 2025 18:02:33 -0800 Subject: [PATCH 04/15] [webgpu] Add implementation of BiasGelu (#26560) ### Description Add implementation of BiasGelu --- .../contrib_ops/webgpu/bert/bias_gelu.cc | 95 +++++++++++++++++++ .../contrib_ops/webgpu/bert/bias_gelu.h | 38 ++++++++ .../webgpu/webgpu_contrib_kernels.cc | 2 + .../test/contrib_ops/element_wise_ops_test.cc | 2 +- 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_gelu.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_gelu.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/bias_gelu.cc new file mode 100644 index 0000000000000..7cad1f4d06d0c --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_gelu.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "contrib_ops/webgpu/bert/bias_gelu.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + BiasGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + BiasGelu); + +Status BiasGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); + const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform); + + shader.AdditionalImplementation() << onnxruntime::webgpu::ErfImpl; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " var a = " << x.GetByOffset("global_idx") << ";\n"; + + // Add bias to input + if (bias_components_ == 1) { + shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n" + " a += x_value_t(" + << bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n"; + } else { + shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } + + // Apply GELU activation: 0.5 * a * (1.0 + erf(a * 0.7071067811865475)) + shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::GeluExpr); + + return Status::OK(); +} + +Status BiasGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias 
= context.Input(1); + auto* output = context.Output(0, input->Shape()); + + uint32_t data_size = onnxruntime::narrow(output->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const auto& input_shape = input->Shape(); + const auto& bias_shape = bias->Shape(); + + // Validate inputs + if (input_shape.NumDimensions() < 1 || bias_shape.NumDimensions() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BiasGelu: input must have at least 1 dimension and bias must be 1-dimensional."); + } + + if (input_shape.GetDims().back() != bias_shape.GetDims().back()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BiasGelu: bias must match the last dimension of input."); + } + + const auto vec_size = (data_size + 3) / 4; + uint32_t bias_size = onnxruntime::narrow(bias->Shape().Size()); + int bias_components = 1; + + if (bias_size % 4 == 0) { + bias_components = 4; + bias_size = bias_size / 4; + } + + BiasGeluProgram program{bias_components}; + program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({vec_size}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/bias_gelu.h new file mode 100644 index 0000000000000..ac7cba249ca61 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_gelu.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class BiasGeluProgram final : public Program { + public: + BiasGeluProgram(int bias_components) : Program{"BiasGelu"}, bias_components_{bias_components} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int bias_components_; +}; + +class BiasGelu final : public WebGpuKernel { + public: + BiasGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index e3573534f94b9..357eebee714d5 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -12,6 +12,7 @@ namespace webgpu { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv); @@ -42,6 +43,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, 
BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index c641103a74465..38659fbd9f2b9 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -109,7 +109,7 @@ TEST(BiasGeluTest, Float) { RunBiasGeluTestFloat({2, 2333}, {2333}); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) static void RunBiasGeluTestHalf(const std::vector& input_dims, const std::vector& bias_dims) { RandomValueGenerator random{2333}; std::vector input_data = random.Uniform(input_dims, -1.0f, 1.0f); From 4f600bbad4a578c72eeda3866920dbc775d86583 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 14 Nov 2025 18:18:21 -0800 Subject: [PATCH 05/15] webgpu csum - axis can be int32 or int64 (#26578) model that reproduces this: onnx-community/granite-4.0-h-350m-ONNX --- onnxruntime/core/providers/webgpu/math/cum_sum.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/cum_sum.cc b/onnxruntime/core/providers/webgpu/math/cum_sum.cc index bc4cd70a238fc..2c0bd6ad17e63 100644 --- a/onnxruntime/core/providers/webgpu/math/cum_sum.cc +++ b/onnxruntime/core/providers/webgpu/math/cum_sum.cc @@ -66,8 +66,12 @@ Status CumSum::ComputeInternal(ComputeContext& context) const { int64_t input_rank = input_shape.NumDimensions(); const auto* axis_tensor = context.Input(1); - const auto* axis_data = axis_tensor->Data(); - int64_t axis = static_cast(axis_data[0]); + int64_t axis; + if (axis_tensor->DataType() == DataTypeImpl::GetType()) { + axis = axis_tensor->Data()[0]; + } else { + axis = static_cast(axis_tensor->Data()[0]); + }; ORT_ENFORCE(-input_rank <= axis && axis < input_rank, 
"Axes attribute must be within range -input_rank <= axis < input_rank."); // Handle negative axis @@ -95,4 +99,4 @@ Status CumSum::ComputeInternal(ComputeContext& context) const { } } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime From 3f4ab76307b299b040bb1104600f078a41bb5e20 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 14 Nov 2025 23:08:38 -0800 Subject: [PATCH 06/15] fix zero size input for where op on webgpu (#26576) fixes an issue with edge-tam-encoder reported by hf --- onnxruntime/core/providers/webgpu/tensor/where.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index d7272ec525296..3560fba522cb8 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -126,6 +126,10 @@ Status Where::ComputeInternal(ComputeContext& context) const { TensorShape output_shape; ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); auto* output_tensor = context.Output(0, output_shape); + if (output_tensor->Shape().Size() == 0) { + return Status::OK(); + } + constexpr int component = 4; uint32_t vec_size = onnxruntime::narrow((output_shape.Size() + 3) / component); const auto is_broadcast = !(x_shape == y_shape && From b39e14432259d5f579f34c3e338f1f70aa368f01 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 16 Nov 2025 12:50:06 -0800 Subject: [PATCH 07/15] [web/webgpu] add `validationMode` to webgpu EP specific options. (#26581) ### Description add `validationMode` to webgpu EP specific options. 
- web: add implementation to pass the option - node: already supported --- js/common/lib/inference-session.ts | 14 ++++++++++++++ js/web/lib/wasm/session-options.ts | 5 +++++ 2 files changed, 19 insertions(+) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 09316966a2fd1..9503127006966 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -258,6 +258,20 @@ export declare namespace InferenceSession { */ forceCpuNodeNames?: readonly string[]; + /** + * Specify the validation mode for WebGPU execution provider. + * - 'disabled': Disable all validation. + * When used in Node.js, disable validation may cause process crash if WebGPU errors occur. Be cautious when using + * this mode. + * When used in web, this mode is equivalent to 'wgpuOnly'. + * - 'wgpuOnly': Perform WebGPU internal validation only. + * - 'basic': Perform basic validation including WebGPU internal validation. This is the default mode. + * - 'full': Perform full validation. This mode may have performance impact. Use it for debugging purpose. + * + * @default 'basic' + */ + validationMode?: 'disabled' | 'wgpuOnly' | 'basic' | 'full'; + /** * Specify an optional WebGPU device to be used by the WebGPU execution provider. 
*/ diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index d9f3ad70f0c23..6b92f10768e45 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -129,6 +129,11 @@ const setExecutionProviders = async ( appendEpOption(epOptions, 'forceCpuNodeNames', names.join('\n'), allocs); } + + // set validation mode + if (webgpuOptions.validationMode) { + appendEpOption(epOptions, 'validationMode', webgpuOptions.validationMode, allocs); + } } const info = getInstance().webgpuRegisterDevice!(customDevice); From 1936d646b2b77d79731c0745e18c424577f9f472 Mon Sep 17 00:00:00 2001 From: xhcao Date: Tue, 18 Nov 2025 02:24:34 +0800 Subject: [PATCH 08/15] webgpu: fix dispatch size issue of Transpose operator (#26501) ### Description ### Motivation and Context --------- Co-authored-by: wp --- .../core/providers/webgpu/tensor/transpose.cc | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 04e3b09a8b790..cec321d0da80e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -162,18 +162,16 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, uint32_t dispatch_z = 1; // This temporary workaround addresses a significant performance bottleneck - // (~12x slower) for the shape (3, 3, 2560, 1280) due to an issue with Intel's + // (~12x slower) for the input shape (1280, 2560, 3, 3) due to an issue with Intel's // GPU drivers. We manually normalize the dispatch group size to restore // performance. // // TODO: Revert this change once the driver issue is fixed. - if (context.AdapterInfo().vendor == std::string_view{"intel"}) { - // Only adjusted the dispatch size when rank is 4 yet. 
- if (rank == static_cast(4)) { - dispatch_x = ceil_div(input_shape[0] * input_shape[1], 2); - dispatch_y = ceil_div(input_shape[2], 4); - dispatch_z = ceil_div(input_shape[3], 8); - } + if (context.AdapterInfo().vendor == std::string_view{"intel"} && rank == 4) { + uint32_t dispatch_size = dispatch_x; + dispatch_x = 4; + dispatch_y = 8; + dispatch_z = ceil_div(dispatch_size, dispatch_x * dispatch_y); } program.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z); } From fe37372401abc0a92be00f7ede491aaab1d0e9ea Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 17 Nov 2025 11:00:07 -0800 Subject: [PATCH 09/15] [Lora] Adjust device dispatch according to the new OrtDevice defs (#26551) ### Description Check the device type and vendor to obtain data transfer. ### Motivation and Context Lora obtains DataTransfer based on the OrtMemoryInfo name which is now arbitrary. We should now rely on memory type and vendor definitions. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/session/lora_adapters.cc | 4 ++-- onnxruntime/test/lora/lora_test.cc | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc index 124d748029fd4..de3acb22e12f8 100644 --- a/onnxruntime/core/session/lora_adapters.cc +++ b/onnxruntime/core/session/lora_adapters.cc @@ -53,11 +53,11 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) { static std::unique_ptr GetDataTransfer(const OrtMemoryInfo& mem_info) { std::unique_ptr data_transfer; - if (mem_info.name == onnxruntime::CPU) { + if (mem_info.device.Type() == OrtDevice::CPU) { return data_transfer; } - if (mem_info.name == onnxruntime::CUDA) { + if (mem_info.device.Type() == OrtDevice::GPU && mem_info.device.Vendor() == OrtDevice::VendorIds::NVIDIA) { #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) auto* cuda_provider_info = TryGetProviderInfo_CUDA(); if 
(cuda_provider_info != nullptr) { diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc index ecfaf34c8a076..0c55cf45abcdf 100644 --- a/onnxruntime/test/lora/lora_test.cc +++ b/onnxruntime/test/lora/lora_test.cc @@ -216,7 +216,9 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) { for (; begin != end; ++begin) { const auto& [_, param] = *begin; const auto& tensor_device = param.GetDeviceOrMapped().Get(); - ASSERT_EQ(0, strcmp(tensor_device.Location().name.c_str(), onnxruntime::CUDA)); + const auto& mem_info = tensor_device.Location(); + ASSERT_EQ(mem_info.device.Type(), OrtDevice::GPU); + ASSERT_EQ(mem_info.device.Vendor(), OrtDevice::VendorIds::NVIDIA); const auto& tensor_cpu = param.GetMapped().Get(); ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); From e8243313aae051f67c7b2b3671bf17efaa3a6698 Mon Sep 17 00:00:00 2001 From: Fengtu Wang <1148791151@qq.com> Date: Tue, 18 Nov 2025 06:02:01 +0800 Subject: [PATCH 10/15] Consistent with the configuration in the packaged cmake (#26104) Shared library path and include path are not same in release file `onnxruntime-linux-x64-gpu-1.22.0.tgz` . 1. The library file named libonnxruntime.so.1.22.0 in `onnxruntimeTargets-release.cmake` is at `lib64/libonnxruntime.so.1.22.0`, but file in `lib` path. 2. The include path in `onnxruntimeTargets.cmake` is in `${_IMPORT_PREFIX}/include/onnxruntime`, but path in `include` path. ### Description 1. image ` IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libonnxruntime.so.1.22.0"` 2. 
image ` INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include/onnxruntime"` ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/24003#issue-2913355372 --- tools/ci_build/github/linux/copy_strip_binary.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index f5b4c38c85d4c..cff7c0a9f038f 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -18,8 +18,6 @@ EXIT_CODE=1 uname -a cd "$BINARY_DIR" mv installed/usr/local $ARTIFACT_NAME -mv $ARTIFACT_NAME/include/onnxruntime/* $ARTIFACT_NAME/include -rmdir $ARTIFACT_NAME/include/onnxruntime # Do not ship onnx_test_runner rm -rf $ARTIFACT_NAME/bin echo "Copy debug symbols in a separate file and strip the original binary." @@ -29,9 +27,6 @@ then strip -S $BINARY_DIR/$ARTIFACT_NAME/lib/$LIB_NAME # copy the CoreML EP header for macOS build (libs with .dylib ext) cp $SOURCE_DIR/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include -else - # Linux - mv $ARTIFACT_NAME/lib64 $ARTIFACT_NAME/lib fi # copy the README, licence and TPN From d55ade03897350f1f2b51b26ba298a789278dab6 Mon Sep 17 00:00:00 2001 From: Hans Date: Tue, 18 Nov 2025 07:43:15 +0800 Subject: [PATCH 11/15] [js/rn] Migrate to JSI implementation (#25764) ### Description close #16031 ### Motivation and Context Make React Native fully zero copy. 
--- .github/workflows/react_native.yml | 93 +-- js/react_native/android/CMakeLists.txt | 104 +++- js/react_native/android/build.gradle | 76 ++- .../android/src/androidTest/Readme.md | 1 - .../reactnative/FakeBlobModule.java | 36 -- .../reactnative/OnnxruntimeModuleTest.java | 203 ------- .../reactnative/TensorHelperTest.java | 565 ------------------ .../androidTest/res/raw/test_types_float.ort | Bin 1872 -> 0 bytes .../androidTest/res/raw/test_types_int32.ort | Bin 1872 -> 0 bytes .../androidTest/res/raw/test_types_int64.ort | Bin 1872 -> 0 bytes .../androidTest/res/raw/test_types_int8.ort | Bin 1872 -> 0 bytes .../androidTest/res/raw/test_types_uint8.ort | Bin 1872 -> 0 bytes .../android/src/main/cpp/cpp-adapter.cpp | 153 ++--- .../OnnxruntimeExtensionsDisabled.java | 5 +- .../OnnxruntimeExtensionsEnabled.java | 6 +- .../reactnative/OnnxruntimeJSIHelper.java | 70 --- .../reactnative/OnnxruntimeModule.java | 452 +------------- .../reactnative/OnnxruntimePackage.java | 1 - .../onnxruntime/reactnative/TensorHelper.java | 289 --------- js/react_native/cpp/AsyncWorker.h | 131 ++++ js/react_native/cpp/Env.h | 50 ++ .../cpp/InferenceSessionHostObject.cpp | 312 ++++++++++ .../cpp/InferenceSessionHostObject.h | 55 ++ js/react_native/cpp/JsiHelper.h | 116 ++++ js/react_native/cpp/JsiMain.cpp | 98 +++ js/react_native/cpp/JsiMain.h | 13 + js/react_native/cpp/JsiUtils.cpp | 32 + js/react_native/cpp/JsiUtils.h | 17 + js/react_native/cpp/SessionUtils.cpp | 450 ++++++++++++++ js/react_native/cpp/SessionUtils.h | 18 + js/react_native/cpp/TensorUtils.cpp | 236 ++++++++ js/react_native/cpp/TensorUtils.h | 26 + js/react_native/e2e/android/app/build.gradle | 9 - .../app/src/main/assets}/test_types_bool.onnx | Bin .../src/main/assets}/test_types_double.onnx | Bin .../app/src/main/assets/test_types_float.ort | Bin 0 -> 1824 bytes .../app/src/main/assets/test_types_int32.ort | Bin 0 -> 1824 bytes .../app/src/main/assets/test_types_int64.ort | Bin 0 -> 1824 bytes 
.../app/src/main/assets/test_types_int8.ort | Bin 0 -> 1824 bytes .../app/src/main/assets/test_types_uint8.ort | Bin 0 -> 1824 bytes .../MNISTDataHandler.java | 4 +- js/react_native/e2e/android/build.gradle | 6 +- js/react_native/e2e/android/gradle.properties | 4 +- js/react_native/e2e/ios/MNISTDataHandler.h | 2 +- js/react_native/e2e/ios/MNISTDataHandler.mm | 5 +- .../project.pbxproj | 28 + js/react_native/e2e/ios/test_types_bool.ort | Bin 0 -> 1824 bytes js/react_native/e2e/ios/test_types_double.ort | Bin 0 -> 1824 bytes js/react_native/e2e/metro.config.js | 1 + js/react_native/e2e/src/App.tsx | 230 +++---- js/react_native/e2e/src/BasicTypesTest.tsx | 387 ++++++++++++ js/react_native/e2e/src/MNISTTest.tsx | 137 +++++ .../src}/test_types_bool.onnx | Bin .../src}/test_types_double.onnx | Bin js/react_native/e2e/src/test_types_float.ort | Bin 0 -> 1824 bytes js/react_native/e2e/src/test_types_int32.ort | Bin 0 -> 1824 bytes js/react_native/e2e/src/test_types_int64.ort | Bin 0 -> 1824 bytes js/react_native/e2e/src/test_types_int8.ort | Bin 0 -> 1824 bytes .../e2e/src/test_types_models.readme.md | 21 + js/react_native/e2e/src/test_types_uint8.ort | Bin 0 -> 1824 bytes .../e2e/test/OnnxruntimeModuleExample.test.js | 47 +- js/react_native/ios/OnnxruntimeJSIHelper.h | 5 - js/react_native/ios/OnnxruntimeJSIHelper.mm | 90 --- js/react_native/ios/OnnxruntimeModule.h | 16 - js/react_native/ios/OnnxruntimeModule.mm | 429 +------------ .../project.pbxproj | 242 ++------ .../contents.xcworkspacedata | 10 - .../FakeRCTBlobManager.h | 23 - .../FakeRCTBlobManager.m | 47 -- .../ios/OnnxruntimeModuleTest/Info.plist | 22 - .../OnnxruntimeModuleTest.mm | 152 ----- .../ios/OnnxruntimeModuleTest/Readme.md | 1 - .../Resources/test_types_float.ort | Bin 1872 -> 0 bytes .../Resources/test_types_int32.ort | Bin 1872 -> 0 bytes .../Resources/test_types_int64.ort | Bin 1872 -> 0 bytes .../Resources/test_types_int8.ort | Bin 1872 -> 0 bytes .../Resources/test_types_uint8.ort | Bin 1872 -> 0 
bytes .../OnnxruntimeModuleTest/TensorHelperTest.mm | 289 --------- js/react_native/ios/Podfile | 46 -- js/react_native/ios/TensorHelper.h | 57 -- js/react_native/ios/TensorHelper.mm | 275 --------- js/react_native/lib/api.ts | 48 ++ js/react_native/lib/backend.ts | 311 +++++----- js/react_native/lib/binding.ts | 123 +--- js/react_native/lib/index.ts | 14 +- .../onnxruntime-react-native.podspec | 17 +- js/react_native/package-lock.json | 25 +- js/react_native/package.json | 2 +- js/react_native/test_types_models.readme.md | 24 - 89 files changed, 2859 insertions(+), 3898 deletions(-) delete mode 100644 js/react_native/android/src/androidTest/Readme.md delete mode 100644 js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java delete mode 100644 js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java delete mode 100644 js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java delete mode 100644 js/react_native/android/src/androidTest/res/raw/test_types_float.ort delete mode 100644 js/react_native/android/src/androidTest/res/raw/test_types_int32.ort delete mode 100644 js/react_native/android/src/androidTest/res/raw/test_types_int64.ort delete mode 100644 js/react_native/android/src/androidTest/res/raw/test_types_int8.ort delete mode 100644 js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort delete mode 100644 js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java delete mode 100644 js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java create mode 100644 js/react_native/cpp/AsyncWorker.h create mode 100644 js/react_native/cpp/Env.h create mode 100644 js/react_native/cpp/InferenceSessionHostObject.cpp create mode 100644 js/react_native/cpp/InferenceSessionHostObject.h create mode 100644 js/react_native/cpp/JsiHelper.h create mode 100644 js/react_native/cpp/JsiMain.cpp create 
mode 100644 js/react_native/cpp/JsiMain.h create mode 100644 js/react_native/cpp/JsiUtils.cpp create mode 100644 js/react_native/cpp/JsiUtils.h create mode 100644 js/react_native/cpp/SessionUtils.cpp create mode 100644 js/react_native/cpp/SessionUtils.h create mode 100644 js/react_native/cpp/TensorUtils.cpp create mode 100644 js/react_native/cpp/TensorUtils.h rename js/react_native/{android/src/androidTest/res/raw => e2e/android/app/src/main/assets}/test_types_bool.onnx (100%) rename js/react_native/{android/src/androidTest/res/raw => e2e/android/app/src/main/assets}/test_types_double.onnx (100%) create mode 100644 js/react_native/e2e/android/app/src/main/assets/test_types_float.ort create mode 100644 js/react_native/e2e/android/app/src/main/assets/test_types_int32.ort create mode 100644 js/react_native/e2e/android/app/src/main/assets/test_types_int64.ort create mode 100644 js/react_native/e2e/android/app/src/main/assets/test_types_int8.ort create mode 100644 js/react_native/e2e/android/app/src/main/assets/test_types_uint8.ort create mode 100644 js/react_native/e2e/ios/test_types_bool.ort create mode 100644 js/react_native/e2e/ios/test_types_double.ort create mode 100644 js/react_native/e2e/src/BasicTypesTest.tsx create mode 100644 js/react_native/e2e/src/MNISTTest.tsx rename js/react_native/{ios/OnnxruntimeModuleTest/Resources => e2e/src}/test_types_bool.onnx (100%) rename js/react_native/{ios/OnnxruntimeModuleTest/Resources => e2e/src}/test_types_double.onnx (100%) create mode 100644 js/react_native/e2e/src/test_types_float.ort create mode 100644 js/react_native/e2e/src/test_types_int32.ort create mode 100644 js/react_native/e2e/src/test_types_int64.ort create mode 100644 js/react_native/e2e/src/test_types_int8.ort create mode 100644 js/react_native/e2e/src/test_types_models.readme.md create mode 100644 js/react_native/e2e/src/test_types_uint8.ort delete mode 100644 js/react_native/ios/OnnxruntimeJSIHelper.h delete mode 100644 
js/react_native/ios/OnnxruntimeJSIHelper.mm delete mode 100644 js/react_native/ios/OnnxruntimeModule.xcworkspace/contents.xcworkspacedata delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Info.plist delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Readme.md delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_float.ort delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int32.ort delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int64.ort delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_int8.ort delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/Resources/test_types_uint8.ort delete mode 100644 js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm delete mode 100644 js/react_native/ios/Podfile delete mode 100644 js/react_native/ios/TensorHelper.h delete mode 100644 js/react_native/ios/TensorHelper.mm create mode 100644 js/react_native/lib/api.ts delete mode 100644 js/react_native/test_types_models.readme.md diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index f827f3bc95456..343186b1aec8c 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -53,14 +53,11 @@ jobs: cp tools/ci_build/github/js/react_native_e2e_full_aar_build_settings.json ${{ runner.temp }}/.build_settings/build_settings.json python3 -m pip install --user -r ${{ github.workspace }}/tools/ci_build/requirements/pybind/requirements.txt - - python3 ${{ github.workspace }}/tools/ci_build/github/android/build_aar_package.py --build_dir ${{ runner.temp }} --config Release --android_sdk_path $ANDROID_SDK_ROOT --android_ndk_path 
$ANDROID_NDK_ROOT ${{ runner.temp }}/.build_settings/build_settings.json + + python3 ${{ github.workspace }}/tools/ci_build/github/android/build_aar_package.py --build_dir ${{ runner.temp }} --config Release --android_sdk_path $ANDROID_SDK_ROOT --android_ndk_path $ANDROID_NDK_ROOT ${{ runner.temp }}/.build_settings/build_settings.json # Copy the built artifacts to give folder for publishing - BASE_PATH=${{ runner.temp }}/aar_out/Release/com/microsoft/onnxruntime/onnxruntime-android/${OnnxRuntimeVersion} - cp ${BASE_PATH}/*.jar ${{ runner.temp }}/artifacts - cp ${BASE_PATH}/*.aar ${{ runner.temp }}/artifacts - cp ${BASE_PATH}/*.pom ${{ runner.temp }}/artifacts + cp -r ${{ runner.temp }}/aar_out/Release/com ${{ runner.temp }}/artifacts - name: Upload Android AAR Artifact uses: actions/upload-artifact@v5 @@ -109,10 +106,8 @@ jobs: - name: Copy AAR to React Native and E2E directories run: | - mkdir -p ${{ github.workspace }}/js/react_native/android/libs - cp ${{ runner.temp }}/android-full-aar/*.aar ${{ github.workspace }}/js/react_native/android/libs mkdir -p ${{ github.workspace }}/js/react_native/e2e/android/app/libs - cp ${{ runner.temp }}/android-full-aar/*.aar ${{ github.workspace }}/js/react_native/e2e/android/app/libs + cp -r ${{ runner.temp }}/android-full-aar/com ${{ github.workspace }}/js/react_native/e2e/android/app/libs - name: Install dependencies and bootstrap run: | @@ -141,10 +136,6 @@ jobs: with: ndk-version: 28.0.13004108 - - name: Run React Native Android Instrumented Tests - run: ./gradlew connectedDebugAndroidTest --stacktrace - working-directory: ${{ github.workspace }}/js/react_native/android - - name: Run React Native Detox Android e2e Tests run: | JEST_JUNIT_OUTPUT_FILE=${{ github.workspace }}/js/react_native/e2e/android-test-results.xml \ @@ -169,6 +160,15 @@ jobs: echo "Emulator PID file was expected to exist but does not." 
fi + - name: Upload Android Test Results + if: always() + uses: actions/upload-artifact@v5 + with: + name: android-test-results + path: | + ${{ github.workspace }}/js/react_native/e2e/android-test-results.xml + ${{ github.workspace }}/js/react_native/e2e/artifacts + react_native_ci_ios_build: name: React Native CI iOS Build runs-on: macos-14 @@ -211,62 +211,6 @@ jobs: name: ios_pod path: ${{ runner.temp }}/ios_pod - react_native_ci_ios_unit_tests: - name: React Native CI iOS Unit Tests - needs: react_native_ci_ios_build - runs-on: macos-14 - timeout-minutes: 90 - steps: - - name: Checkout repository - uses: actions/checkout@v5 - - - name: Download iOS pod artifact - uses: actions/download-artifact@v6 - with: - name: ios_pod - path: ${{ runner.temp }}/ios_pod - - - name: Use Xcode 15.3.0 - run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer - - - name: Use Node.js 22.x - uses: actions/setup-node@v6 - with: - node-version: '22.x' - - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 - with: - vcpkg-version: '2025.06.13' - vcpkg-hash: 735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc - cmake-version: '3.31.8' - cmake-hash: 99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8 - add-cmake-to-path: 'true' - disable-terrapin: 'true' - - - name: Install dependencies and bootstrap - run: | - npm ci - working-directory: ${{ github.workspace }}/js - - run: npm ci - working-directory: ${{ github.workspace }}/js/common - - run: | - set -e -x - npm ci - npm run bootstrap-no-pods - working-directory: ${{ github.workspace }}/js/react_native - - - name: Pod install - run: | - set -e -x - ls ${{ runner.temp }}/ios_pod/onnxruntime-c - ORT_C_LOCAL_POD_PATH=${{ runner.temp }}/ios_pod/onnxruntime-c pod install --verbose - working-directory: ${{ github.workspace }}/js/react_native/ios - 
- - name: Run React Native iOS Instrumented Tests - run: | - /usr/bin/xcodebuild -sdk iphonesimulator -configuration Debug -workspace ${{ github.workspace }}/js/react_native/ios/OnnxruntimeModule.xcworkspace -scheme OnnxruntimeModuleTest -destination 'platform=iOS Simulator,name=iPhone 15,OS=17.4' test CODE_SIGNING_ALLOWED=NO - working-directory: ${{ github.workspace }}/js/react_native/ios - react_native_ci_ios_e2e_tests: name: React Native CI iOS E2E Tests needs: react_native_ci_ios_build @@ -314,7 +258,7 @@ jobs: npm ci npm run bootstrap-no-pods working-directory: ${{ github.workspace }}/js/react_native - + - name: Pod install for e2e tests run: | set -e -x @@ -331,3 +275,12 @@ jobs: --loglevel verbose \ --take-screenshots failing working-directory: ${{ github.workspace }}/js/react_native/e2e + + - name: Upload iOS Test Results + if: always() + uses: actions/upload-artifact@v5 + with: + name: ios-test-results + path: | + ${{ github.workspace }}/js/react_native/e2e/ios-test-results.xml + ${{ github.workspace }}/js/react_native/e2e/artifacts diff --git a/js/react_native/android/CMakeLists.txt b/js/react_native/android/CMakeLists.txt index 98f30daac6372..2f814e871ad77 100644 --- a/js/react_native/android/CMakeLists.txt +++ b/js/react_native/android/CMakeLists.txt @@ -1,37 +1,99 @@ -project(OnnxruntimeJSIHelper) +project(OnnxruntimeJSI) cmake_minimum_required(VERSION 3.9.0) -set (PACKAGE_NAME "onnxruntime-react-native") -set (BUILD_DIR ${CMAKE_SOURCE_DIR}/build) +set(PACKAGE_NAME "onnxruntime-react-native") +set(BUILD_DIR ${CMAKE_SOURCE_DIR}/build) set(CMAKE_VERBOSE_MAKEFILE ON) set(CMAKE_CXX_STANDARD 17) -file(TO_CMAKE_PATH "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath) +option(ORT_EXTENSIONS_ENABLED "Enable Ort Extensions" NO) +option(USE_NNAPI "Use NNAPI" YES) +option(USE_QNN "Use QNN" NO) + +file(GLOB libfbjni_link_DIRS "${BUILD_DIR}/fbjni-*.aar/jni/${ANDROID_ABI}") +file(GLOB libfbjni_include_DIRS "${BUILD_DIR}/fbjni-*-headers.jar/") + 
+file(GLOB onnxruntime_include_DIRS + "${BUILD_DIR}/onnxruntime-android-*.aar/headers") +file(GLOB onnxruntime_link_DIRS + "${BUILD_DIR}/onnxruntime-android-*.aar/jni/${ANDROID_ABI}/") + +if(ORT_EXTENSIONS_ENABLED) + add_definitions(-DORT_ENABLE_EXTENSIONS) +endif() + +if(USE_QNN) + add_definitions(-DUSE_QNN) +endif() + +if(USE_NNAPI) + add_definitions(-DUSE_NNAPI) +endif() + +file(TO_CMAKE_PATH + "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath) + +find_package(fbjni REQUIRED CONFIG) +find_package(ReactAndroid REQUIRED CONFIG) + +find_library( + onnxruntime-lib onnxruntime + PATHS ${onnxruntime_link_DIRS} + NO_CMAKE_FIND_ROOT_PATH) + +set(RN_INCLUDES + "${NODE_MODULES_DIR}/react-native/React" + "${NODE_MODULES_DIR}/react-native/React/Base" + "${NODE_MODULES_DIR}/react-native/ReactCommon" + "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi" + "${NODE_MODULES_DIR}/react-native/ReactCommon/callinvoker") + +if(${REACT_NATIVE_VERSION} VERSION_GREATER_EQUAL "0.76") + set(RN_LIBS + ReactAndroid::reactnative + ReactAndroid::jsi) +else() + list( + APPEND + RN_INCLUDES + "${NODE_MODULES_DIR}/react-native/ReactAndroid/src/main/java/com/facebook/react/turbomodule/core/jni" + ) + set(RN_LIBS + ReactAndroid::jsi + ReactAndroid::react_nativemodule_core + ReactAndroid::turbomodulejsijni) +endif() include_directories( - "${NODE_MODULES_DIR}/react-native/React" - "${NODE_MODULES_DIR}/react-native/React/Base" - "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi" -) + ../cpp + ${RN_INCLUDES} + ${onnxruntime_include_DIRS} + ${libfbjni_include_DIRS}) -add_library(onnxruntimejsihelper - SHARED - ${libPath} - src/main/cpp/cpp-adapter.cpp -) +add_library( + onnxruntimejsi SHARED + ${libPath} + src/main/cpp/cpp-adapter.cpp + ../cpp/JsiMain.cpp + ../cpp/InferenceSessionHostObject.cpp + ../cpp/JsiUtils.cpp + ../cpp/SessionUtils.cpp + ../cpp/TensorUtils.cpp) # Configure C++ 17 set_target_properties( - onnxruntimejsihelper PROPERTIES - CXX_STANDARD 17 - CXX_EXTENSIONS OFF 
- POSITION_INDEPENDENT_CODE ON -) + onnxruntimejsi + PROPERTIES CXX_STANDARD 17 + CXX_EXTENSIONS OFF + POSITION_INDEPENDENT_CODE ON) find_library(log-lib log) target_link_libraries( - onnxruntimejsihelper - ${log-lib} # <-- Logcat logger - android # <-- Android JNI core + onnxruntimejsi + ${onnxruntime-lib} + fbjni::fbjni + ${RN_LIBS} + ${log-lib} # <-- Logcat logger + android # <-- Android JNI core ) diff --git a/js/react_native/android/build.gradle b/js/react_native/android/build.gradle index 2f5b5adc7a1fa..41b43599a9af6 100644 --- a/js/react_native/android/build.gradle +++ b/js/react_native/android/build.gradle @@ -48,23 +48,22 @@ static def findNodeModules(baseDir) { def nodeModules = findNodeModules(projectDir); -def checkIfOrtExtensionsEnabled() { +def readPackageJsonField(field) { // locate user's project dir def reactnativeRootDir = project.rootDir.parentFile // get package.json file in root directory def packageJsonFile = new File(reactnativeRootDir, 'package.json') - // read field 'onnxruntimeExtensionsEnabled' if (packageJsonFile.exists()) { def packageJsonContents = packageJsonFile.getText() def packageJson = new groovy.json.JsonSlurper().parseText(packageJsonContents) - return packageJson.onnxruntimeExtensionsEnabled == "true" + return packageJson.get(field) } else { - logger.warn("Could not find package.json file in the expected directory: ${reactnativeRootDir}. ONNX Runtime Extensions will not be enabled.") + logger.warn("Could not find package.json file in the expected directory: ${reactnativeRootDir}. 
${field} will not be enabled.") } - return false } -boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled() +boolean ortExtensionsEnabled = readPackageJsonField('onnxruntimeExtensionsEnabled') == "true" +boolean useQnn = readPackageJsonField('onnxruntimeUseQnn') == "true" def REACT_NATIVE_VERSION = ['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim() def REACT_NATIVE_MINOR_VERSION = REACT_NATIVE_VERSION.split("\\.")[1].toInteger() @@ -85,9 +84,18 @@ android { cppFlags "-O2 -frtti -fexceptions -Wall -Wno-unused-variable -fstack-protector-all" if (REACT_NATIVE_MINOR_VERSION >= 71) { // fabricjni required c++_shared - arguments "-DANDROID_STL=c++_shared", "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}" + arguments "-DANDROID_STL=c++_shared", + "-DNODE_MODULES_DIR=${nodeModules}", + "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}", + "-DREACT_NATIVE_VERSION=${REACT_NATIVE_VERSION}", + "-DUSE_QNN=${useQnn}", + "-DUSE_NNAPI=${!useQnn}" } else { - arguments "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}" + arguments "-DNODE_MODULES_DIR=${nodeModules}", + "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}", + "-DREACT_NATIVE_VERSION=${REACT_NATIVE_VERSION}", + "-DUSE_QNN=${useQnn}", + "-DUSE_NNAPI=${!useQnn}" } abiFilters (*reactNativeArchitectures()) } @@ -119,6 +127,9 @@ android { "META-INF", "META-INF/**", "**/libjsi.so", + "**/libfbjni.so", + "**/libreact_nativemodule_core.so", + "**/libturbomodulejsijni.so" ] } @@ -147,6 +158,10 @@ android { } } } + + configurations { + extractLibs + } } repositories { @@ -217,10 +232,6 @@ repositories { "Ensure you have you installed React Native as a dependency in your project and try again." 
) } - - flatDir { - dir 'libs' - } } dependencies { @@ -228,16 +239,47 @@ dependencies { implementation "com.facebook.react:react-android:"+ REACT_NATIVE_VERSION api "org.mockito:mockito-core:2.28.2" - androidTestImplementation "androidx.test:runner:1.5.2" - androidTestImplementation "androidx.test:rules:1.5.0" implementation "junit:junit:4.12" - androidTestImplementation "com.linkedin.dexmaker:dexmaker-mockito-inline-extended:2.28.1" + if (useQnn) { + extractLibs "com.microsoft.onnxruntime:onnxruntime-android-qnn:latest.integration@aar" + } else { + extractLibs "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar" + } - implementation "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar" + if (VersionNumber.parse(REACT_NATIVE_VERSION) < VersionNumber.parse("0.71")) { + extractLibs "com.facebook.fbjni:fbjni:+:headers" + extractLibs "com.facebook.fbjni:fbjni:+" + } // By default it will just include onnxruntime full aar package if (ortExtensionsEnabled) { implementation "com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.integration@aar" } -} \ No newline at end of file +} + +task extractLibs { + doLast { + configurations.extractLibs.files.each { + def file = it.absoluteFile + copy { + from zipTree(file) + into "$buildDir/$file.name" + include "**/*.h", "**/*.so" + } + } + } +} + +def nativeBuildDependsOn(dependsOnTask, variant) { + def buildTasks = tasks.findAll({ task -> + !task.name.contains("Clean") && (task.name.contains("externalNative") || task.name.contains("CMake")) }) + if (variant != null) { + buildTasks = buildTasks.findAll({ task -> task.name.contains(variant) }) + } + buildTasks.forEach { task -> task.dependsOn(dependsOnTask) } +} + +afterEvaluate { + nativeBuildDependsOn(extractLibs, null) +} diff --git a/js/react_native/android/src/androidTest/Readme.md b/js/react_native/android/src/androidTest/Readme.md deleted file mode 100644 index b0376602af908..0000000000000 --- 
a/js/react_native/android/src/androidTest/Readme.md +++ /dev/null @@ -1 +0,0 @@ -Please see [here](/js/react_native/test_types_models.readme.md) for information on the test models. diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java deleted file mode 100644 index 82d063ad51e3f..0000000000000 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package ai.onnxruntime.reactnative; - -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.JavaOnlyMap; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.modules.blob.BlobModule; - -public class FakeBlobModule extends BlobModule { - - public FakeBlobModule(ReactApplicationContext context) { super(null); } - - @Override - public String getName() { - return "BlobModule"; - } - - public JavaOnlyMap testCreateData(byte[] bytes) { - String blobId = store(bytes); - JavaOnlyMap data = new JavaOnlyMap(); - data.putString("blobId", blobId); - data.putInt("offset", 0); - data.putInt("size", bytes.length); - return data; - } - - public byte[] testGetData(ReadableMap data) { - String blobId = data.getString("blobId"); - int offset = data.getInt("offset"); - int size = data.getInt("size"); - return resolve(blobId, offset, size); - } -} diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java deleted file mode 100644 index b15b1a468ae29..0000000000000 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java +++ 
/dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package ai.onnxruntime.reactnative; - -import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession; -import static org.mockito.Mockito.when; - -import ai.onnxruntime.TensorInfo; -import android.util.Base64; -import androidx.test.platform.app.InstrumentationRegistry; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.CatalystInstance; -import com.facebook.react.bridge.JavaOnlyArray; -import com.facebook.react.bridge.JavaOnlyMap; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReadableArray; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.bridge.WritableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.MockitoSession; - -public class OnnxruntimeModuleTest { - private ReactApplicationContext reactContext = - new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext()); - - private FakeBlobModule blobModule; - - private static byte[] getInputModelBuffer(InputStream modelStream) throws Exception { - ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); - - int bufferSize = 1024; - byte[] buffer = new byte[bufferSize]; - - int len; - while ((len = modelStream.read(buffer)) != -1) { - byteBuffer.write(buffer, 0, len); - } - - byte[] modelBuffer = byteBuffer.toByteArray(); - - return modelBuffer; - } - - @Before - public void setUp() { - blobModule = new FakeBlobModule(reactContext); - } - - @Test - public void getName() throws Exception { - OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); - ortModule.blobModule = 
blobModule; - String name = "Onnxruntime"; - Assert.assertEquals(ortModule.getName(), name); - } - - @Test - public void onnxruntime_module() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); - ortModule.blobModule = blobModule; - String sessionKey = ""; - - // test loadModel() - { - try (InputStream modelStream = - reactContext.getResources().openRawResource(ai.onnxruntime.reactnative.test.R.raw.test_types_float);) { - byte[] modelBuffer = getInputModelBuffer(modelStream); - - JavaOnlyMap options = new JavaOnlyMap(); - try { - ReadableMap resultMap = ortModule.loadModel(modelBuffer, options); - sessionKey = resultMap.getString("key"); - ReadableArray inputNames = resultMap.getArray("inputNames"); - ReadableArray outputNames = resultMap.getArray("outputNames"); - - Assert.assertEquals(inputNames.size(), 1); - Assert.assertEquals(inputNames.getString(0), "input"); - Assert.assertEquals(outputNames.size(), 1); - Assert.assertEquals(outputNames.getString(0), "output"); - } catch (Exception e) { - Assert.fail(e.getMessage()); - } - } - } - - int[] dims = new int[] {1, 5}; - float[] inputData = new float[] {1.0f, 2.0f, -3.0f, Float.MIN_VALUE, Float.MAX_VALUE}; - - // test run() - { - JavaOnlyMap inputDataMap = new JavaOnlyMap(); - { - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dimsArray = new JavaOnlyArray(); - for (int dim : dims) { - dimsArray.pushInt(dim); - } - inputTensorMap.putArray("dims", dimsArray); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeFloat); - - ByteBuffer buffer = ByteBuffer.allocate(5 * Float.BYTES).order(ByteOrder.nativeOrder()); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (float value : inputData) { - 
floatBuffer.put(value); - } - floatBuffer.rewind(); - inputTensorMap.putMap("data", blobModule.testCreateData(buffer.array())); - - inputDataMap.putMap("input", inputTensorMap); - } - - JavaOnlyArray outputNames = new JavaOnlyArray(); - outputNames.pushString("output"); - - JavaOnlyMap options = new JavaOnlyMap(); - options.putBoolean("encodeTensorData", true); - - try { - ReadableMap resultMap = ortModule.run(sessionKey, inputDataMap, outputNames, options); - - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat); - ReadableMap data = outputMap.getMap("data"); - FloatBuffer buffer = - ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); - } - } catch (Exception e) { - Assert.fail(e.getMessage()); - } - } - - // test dispose - ortModule.dispose(sessionKey); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void onnxruntime_module_append_nnapi() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); - ortModule.blobModule = blobModule; - String sessionKey = ""; - - // test loadModel() with nnapi ep options - - try (InputStream modelStream = - reactContext.getResources().openRawResource(ai.onnxruntime.reactnative.test.R.raw.test_types_float);) { - - byte[] modelBuffer = getInputModelBuffer(modelStream); - - // register with nnapi ep options - JavaOnlyMap options = new JavaOnlyMap(); - JavaOnlyArray epArray = new JavaOnlyArray(); - epArray.pushString("nnapi"); - 
options.putArray("executionProviders", epArray); - - try { - ReadableMap resultMap = ortModule.loadModel(modelBuffer, options); - sessionKey = resultMap.getString("key"); - ReadableArray inputNames = resultMap.getArray("inputNames"); - ReadableArray outputNames = resultMap.getArray("outputNames"); - - Assert.assertEquals(inputNames.size(), 1); - Assert.assertEquals(inputNames.getString(0), "input"); - Assert.assertEquals(outputNames.size(), 1); - Assert.assertEquals(outputNames.getString(0), "output"); - } catch (Exception e) { - Assert.fail(e.getMessage()); - } - } - ortModule.dispose(sessionKey); - } finally { - mockSession.finishMocking(); - } - } -} diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java deleted file mode 100644 index 72518488e6682..0000000000000 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ /dev/null @@ -1,565 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -package ai.onnxruntime.reactnative; - -import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession; -import static org.mockito.Mockito.when; - -import ai.onnxruntime.OnnxJavaType; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtSession; -import ai.onnxruntime.OrtUtil; -import ai.onnxruntime.TensorInfo; -import android.content.Context; -import android.util.Base64; -import androidx.test.filters.SmallTest; -import androidx.test.platform.app.InstrumentationRegistry; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.JavaOnlyArray; -import com.facebook.react.bridge.JavaOnlyMap; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import java.nio.ShortBuffer; -import java.util.HashMap; -import java.util.Map; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.MockitoSession; - -@SmallTest -public class TensorHelperTest { - private ReactApplicationContext reactContext = - new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext()); - - private OrtEnvironment ortEnvironment; - - private FakeBlobModule blobModule; - - @Before - public void setUp() { - ortEnvironment = OrtEnvironment.getEnvironment("TensorHelperTest"); - blobModule = new FakeBlobModule(reactContext); - } - - @Test - public void createInputTensor_float32() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new float[] {Float.MIN_VALUE, 2.0f, Float.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - 
JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeFloat); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 4).order(ByteOrder.nativeOrder()); - FloatBuffer dataFloatBuffer = dataByteBuffer.asFloatBuffer(); - dataFloatBuffer.put(Float.MIN_VALUE); - dataFloatBuffer.put(2.0f); - dataFloatBuffer.put(Float.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getFloatBuffer().array(), outputTensor.getFloatBuffer().array(), 1e-6f); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_int8() throws Exception { - OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, new byte[] {Byte.MIN_VALUE, 2, Byte.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeByte); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3); - dataByteBuffer.put(Byte.MIN_VALUE); - dataByteBuffer.put((byte)2); - dataByteBuffer.put(Byte.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8); - 
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_uint8() throws Exception { - OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, ByteBuffer.wrap(new byte[] {0, 2, (byte)255}), - new long[] {3}, OnnxJavaType.UINT8); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeUnsignedByte); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3); - dataByteBuffer.put((byte)0); - dataByteBuffer.put((byte)2); - dataByteBuffer.put((byte)255); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_int32() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new int[] {Integer.MIN_VALUE, 2, Integer.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", 
TensorHelper.JsTensorTypeInt); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 4).order(ByteOrder.nativeOrder()); - IntBuffer dataIntBuffer = dataByteBuffer.asIntBuffer(); - dataIntBuffer.put(Integer.MIN_VALUE); - dataIntBuffer.put(2); - dataIntBuffer.put(Integer.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getIntBuffer().array(), outputTensor.getIntBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_int64() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new long[] {Long.MIN_VALUE, 15000000001L, Long.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeLong); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 8).order(ByteOrder.nativeOrder()); - LongBuffer dataLongBuffer = dataByteBuffer.asLongBuffer(); - dataLongBuffer.put(Long.MIN_VALUE); - dataLongBuffer.put(15000000001L); - dataLongBuffer.put(Long.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); - Assert.assertEquals(outputTensor.getInfo().onnxType, 
TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getLongBuffer().array(), outputTensor.getLongBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_double() throws Exception { - OnnxTensor outputTensor = - OnnxTensor.createTensor(ortEnvironment, new double[] {Double.MIN_VALUE, 1.8e+30, Double.MAX_VALUE}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(3); - inputTensorMap.putArray("dims", dims); - - inputTensorMap.putString("type", TensorHelper.JsTensorTypeDouble); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(3 * 8).order(ByteOrder.nativeOrder()); - DoubleBuffer dataDoubleBuffer = dataByteBuffer.asDoubleBuffer(); - dataDoubleBuffer.put(Double.MIN_VALUE); - dataDoubleBuffer.put(1.8e+30); - dataDoubleBuffer.put(Double.MAX_VALUE); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); - Assert.assertEquals(outputTensor.getInfo().onnxType, - TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getDoubleBuffer().array(), outputTensor.getDoubleBuffer().array(), 1e-6f); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createInputTensor_bool() throws Exception { - OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, new boolean[] {false, true}); - - JavaOnlyMap inputTensorMap = new JavaOnlyMap(); - - JavaOnlyArray dims = new JavaOnlyArray(); - dims.pushInt(2); - inputTensorMap.putArray("dims", dims); - - 
inputTensorMap.putString("type", TensorHelper.JsTensorTypeBool); - - ByteBuffer dataByteBuffer = ByteBuffer.allocate(2); - dataByteBuffer.put((byte)0); - dataByteBuffer.put((byte)1); - inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - - OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); - - Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); - Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); - Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); - Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); - - inputTensor.close(); - outputTensor.close(); - } - - @Test - public void createOutputTensor_bool() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_bool); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - boolean[] inputData = new boolean[] {true, false, false, true, false}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) 
{ - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeBool); - ReadableMap data = outputMap.getMap("data"); - ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i) == 1, inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_double() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_double); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - double[] inputData = new double[] {1.0f, 2.0f, -3.0f, Double.MIN_VALUE, Double.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeDouble); - ReadableMap data = outputMap.getMap("data"); - DoubleBuffer buffer = - ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asDoubleBuffer(); - for (int i = 0; i < 5; ++i) { 
- Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_float() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_float); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - float[] inputData = new float[] {1.0f, 2.0f, -3.0f, Float.MIN_VALUE, Float.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat); - ReadableMap data = outputMap.getMap("data"); - FloatBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_int8() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try 
{ - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_int8); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - byte[] inputData = new byte[] {1, 2, -3, Byte.MAX_VALUE, Byte.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeByte); - ReadableMap data = outputMap.getMap("data"); - ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_int32() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_int32); - OrtSession session = ortEnvironment.createSession(modelData, 
options); - - long[] dims = new long[] {1, 5}; - int[] inputData = new int[] {1, 2, -3, Integer.MIN_VALUE, Integer.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeInt); - ReadableMap data = outputMap.getMap("data"); - IntBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asIntBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_int64() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_int64); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - long[] inputData = new long[] {1, 2, -3, Long.MIN_VALUE, Long.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - Object tensorInput = OrtUtil.reshape(inputData, dims); - OnnxTensor onnxTensor = 
OnnxTensor.createTensor(ortEnvironment, tensorInput); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) { - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeLong); - ReadableMap data = outputMap.getMap("data"); - LongBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asLongBuffer(); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - @Test - public void createOutputTensor_uint8() throws Exception { - MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); - try { - when(Arguments.createMap()).thenAnswer(i -> new JavaOnlyMap()); - when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); - - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - byte[] modelData = readBytesFromResourceFile(ai.onnxruntime.reactnative.test.R.raw.test_types_uint8); - OrtSession session = ortEnvironment.createSession(modelData, options); - - long[] dims = new long[] {1, 5}; - byte[] inputData = new byte[] {1, 2, -3, Byte.MAX_VALUE, Byte.MAX_VALUE}; - - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - ByteBuffer inputBuffer = ByteBuffer.wrap(inputData); - OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, inputBuffer, dims, OnnxJavaType.UINT8); - container.put(inputName, onnxTensor); - - OrtSession.Result result = session.run(container); - - ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - ReadableMap outputMap = resultMap.getMap("output"); - for (int i = 0; i < 2; ++i) 
{ - Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); - } - Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeUnsignedByte); - ReadableMap data = outputMap.getMap("data"); - ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); - for (int i = 0; i < 5; ++i) { - Assert.assertEquals(buffer.get(i), inputData[i]); - } - - OnnxValue.close(container); - } finally { - mockSession.finishMocking(); - } - } - - private byte[] readBytesFromResourceFile(int resourceId) throws Exception { - Context context = InstrumentationRegistry.getInstrumentation().getContext(); - InputStream inputStream = context.getResources().openRawResource(resourceId); - ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); - - int bufferSize = 1024; - byte[] buffer = new byte[bufferSize]; - - int len; - while ((len = inputStream.read(buffer)) != -1) { - byteBuffer.write(buffer, 0, len); - } - - return byteBuffer.toByteArray(); - } -} diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_float.ort b/js/react_native/android/src/androidTest/res/raw/test_types_float.ort deleted file mode 100644 index e5c40742843d5a3cf54d5581cffcfb586ab14c8b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1872 zcmZ9MJ!lkB5Xax@nVeB$)F6jQNO47qutfYof`uQk2>5YmlpsMkE}LW_w|i_hA68x= zB33CZEG$JVM8wL^B3Th*dz#H)a!*dyV=0r`0N8dq8?@yVV-B_Ac}ykiQp<-&_o9!T5=j*we&`aR4Z9 zDvV}PIS2YUaC_fcd5xdw-K=k+#a*;LiAVKcSx)sQe+np0Tv{`xhm2l+Gw}*kQ|&=B zYKDm_Wp;i}^?VLcN^&3543k&}O_J1ylR;Cq`GfXL6|yr=5L6_#1jMJFq_;tdJ^89B zJpqjA%la(PWgnh;Cbdn5%ox3nWcBiMm0-+or(lMN9J45+&ja!5JLnz`fSEnH-N1Ij zQr3g|aL}IY3Vo`FCOQQRUoH460d}_LBg^TUv|}Ux98T`Ardxgj z(Qo2pBi?sxMmpntZS!62_xQSxA}8JTkL%R>tGg)0dc)# zX2m0suNrYO?>FwHaZ+3Gt7#BvZ=);j7;Ck-?$^Uw5--JbsUNI0f+Va5QCeLzrpfno zyvi51{#!b7@JHI>6r=yTYUsyNwB|RGU?mJz{WPhDQ5em$Wv090)vt?hkJB3*psMEM zkP@w^hUPizW62MrG)SWAq8}}+))ev%e$AIev6^$oU*}w3y6PWczvE1PR1;N^>iMfM W<{&W=@j20aT$r4U=923Fd;AB>i48*l diff --git 
a/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int32.ort deleted file mode 100644 index 6135c9a4aca7cba9ad1433547f37c08f9beb7cf8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1872 zcmZ9MJ!lkB5Xax@nVeB$)F6jQNO47qutfYof`uQk2>5YmlpsMkE}LW_w|i_hA68x= zB33CZEG$JVM8wLdx(Pr-^4mVCzakkQL;MqYtxsy%2% z%`j1=OwZ?3&*xxDiSDDCVG_%5lO*-wXwZ~x{-C|7LN@aRK}BLqu<+EASkB@px1DVQN5$1InT=V9UM9`p_e;F&$S-avPR zE9*ghIM$i;3Vo`FCNhN=;2BtRT&moyzHrEbdk;;0XzzC zpr^3#fBen&VUq`YY<1jT6J776n2YF=dedQ8`fcm0ATd4Riz4fP!{cFkn#Iu_e0t?4 z0R1LTI^=yvXQeaT*EU~ys`DOO?@?s(UVdGUdU5pD;!fbaW_8ZI33P|sgDyOWd(_%z zk#g_gX&3Tc)GkQ-MbZo6T2}k9X?rX^{Rc-JYVtU_j9=8e%yQH-(DfZ$m0e&;<6UNVYSRt_zG<2G!Ma6KblqF z@(67EL(A{d$({Sl{bxA%*p*+#r+in_^(forOFwP;v#(2EVv#BPZJz7H;o$C|Zk$t` z%BOqM9mdOPW0}R5tywga`jM?T^_+mc*1`F#Y zJ1ZQKY}JU9dB1Tljg#7fUrmEZXB%B{$MDtSx?c}#NxT%#rGBv12$HZKL}_)=m?rn> zaFs7?{kL@FV2`xLDMtTw)zFWlXw7dV!Acmc`e{-PqcEE1$V_*|t6vx29;a_`fU4S$ zLrS!y8rtWqk0n2h(jbYdi+;4QT2sh7@U>qe#cIzTex13#bk#q?dB;qCR1;K@>iJh; W%t2yA!gHkkxG*^x%_Y_U@Awb%i49Ev diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int64.ort deleted file mode 100644 index a9892d9ec598d8bdec1ccb684c49062d3d3f9bbe..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1872 zcmZ8hJ!lkB5T4aDS);~N4>?3aiYrot<-{K(Som`m0e=pS5+n%6W|LdEdwXm)KUQ8L zB33CZEG$JVM8wL z98qO$n{%?~a{x+F?!%OEBvybflGulnl2mMd$$qJVK64F0K4QxN`PdWnHc-YM*Q`lh z14Kr2p9OR|f)#s4YM+cr6}?XAdgfeZFq+#b$QUB478U680D1Wxcn=4Fxjn}2z;*(c z?ty(caF2E!eX@rxbOtnmIe_Q5RKIn3Eee;jI8pU9S&8SP6mB7(zG^OV!o5czo&+{v zX8`$s{5A8ii3@vdZMeNIY>b(Xxd@xFywh>(yKQ}SXt+G`Q=TfdL25?JKD#~dwk9$JA3N^L-8vd>?lo=Is^q+cR;+fjHd)craS>6u1J|bIL<-vme&y zTb=-Hf3VEAbb9Ce8UGn3Zba<Kn~`_82RjrvFPHAfOKL9c^Ln1F#Ov166=r zFEO*^p>WL_NxB$x?qx|DHiJeMDei6L;*P-=CheddhiS5uEM!5n)``-%9jUC*66xZ5 
zI$q|JoBu5XIap8h#W9Beb&WVklv)cqX|xhYt3j4FVil`JY{_gcp8dM`{y4tDBC2ve z4l&URYjB=*A4@^3vM5!JR-l@zAw%AQpYugwEa%+uw{fm7x%x-2KX9f%g{1OP-Mk8s VgNUJ!&k5(_!t}ISNE`p}@gMt%4Nm|7 diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort b/js/react_native/android/src/androidTest/res/raw/test_types_int8.ort deleted file mode 100644 index f1bf199e488e1059eaf043a7bb9caa8ce5167673..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1872 zcmZ8hJ7^SP5T4aDS);~N4>?3aiYrot<-`XPEPR|rz{jCcf&}5%Y;p^C{~nvo!^$f} z#43e_g{6pvh*()!iG@`vv67S)A{HV6gXj1C4|BHzbMyFTzM0v7{ws^f^&2x+A!NT4 z!9KJCByEvRaAjEJ^PgU?0Nn#qR=0b-lGS%Xw!pOaf~Rh_;xL*zaT5A0;)qxTGjAqV z^AH9D@;KO?ee0PTKOuL+zXeHb-}Xct)_-L&*&pp0Fyn|tpU4n0hViS(E1(){mt<1L z5mnZ@nUg)A1EUn_K2#Y;Vg=kq68msel8VhQ*)LVlGuII0Beo1C9($tP2A6S0n>8tG zV385sX8~P~z+%rx?V~ZNqSpys&&*W@qPd-dj3KgWxdM3}Ok92k-opX#+#aKMpgV#~ z_rN|J^Ne~OeX@rxWCm`6=fIrfQvKHDwJ2Q9;zZTcWF?-DQn-b9`l@N(yKQ}SNVq)Wla%$p!R5j9IE$k@@bQ+P z0O&X3sDr%k(6Q3G)@z&3Jl1>k@?reC9QNYqt;MaudClrL^J>r??hth1IozYx zJ_||r4jy;G>>_u8+Ak!%K%A@n47$U;YjGT|vi3S;9(S}4m-qObNp|=4VOrMX+XHz% zy5K&trn`qv7Ms0t-#E`#_{{e?Jn((oeVDgb&~MMg6$j#U3%15n$SYtwr#u8U`{As7 z%M)PRA2jnVo!t3;#(xF}A9MK?_?Yi%x*ln}eCnrVRKG61L?2U*+C0~X!-0E%@;Ik( zN*})y|9g_8r_EV76Z@gfIQDGn8_9h37%QEo|4Sbrpbq9c+QQif@H%)NTm=*B zC3coLB-*Tzq>Dl4UY4X`GiYRy;@L(n?ihSw(hk~jm?lffLKZ}8ohXglk;)n^kuJWc z!(~3P`QI{-1AC$`jxqeNYs5jK)LPI{K?AOKj$MFpoQI-2~ zh>4b1gZr%eSPEj5MX74E0@Yj%8S)PJ++h-9x#teQjk&($>L0=Rz)XP(3FV`@c@-iD U5kn%LBksqA>1nl)HvZq^Kii27M*si- diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort b/js/react_native/android/src/androidTest/res/raw/test_types_uint8.ort deleted file mode 100644 index 9f5310803323a5cd6e3db7cdae24041143608da2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1872 zcmZ8hJ7^S96uqNkGDeN7F0zP(6kDVS!-^jySom3sK-5K}1PQ`2J4pt1-z<~K$I2~4 z#43e_g{6py45k~RiQiHA?W0DyvMD5 z9wXhme7p;+;_JIW?H7_>BF@!*4&CwIvpkMhS$hNHAn#}oFVFD>lkD9)VC(S>LA(H6 
z@IJC8dt>jPEH`c5H~PN9XZGi4-~PDsFz?yd@7cr^`r>p8;K9|BXTS_#*C~&{&3?F6 zzV8XZ_6N=Er5ktm&-l+_;^RPm89wH_nyyFME}!~Y8P%_gFEPL>M{S<#!|}k~Cxbj^ zFiRiziT|^7C2z0b;?w3@djDCBW6!3(8JW)>6Q$Gizw{vjY5?ESHs(G68$bi70OWp! zJ4+r4ZLT@Z7Q^=aJk6q|u%5??bDINkC*X_HR@h3SEL~0)@-SX+$63;fRbFq3bg-X} zm-*!8zojn+_C!w{WB9MDCt<48df3k5)g)dE^Q@kzL@i=U<_6-~uZ!=E;~OlXD(}Y; zCR$+)-e=v%a+s(*&Q!e_s-?AvA@9J?`$b_a|FMq0g>`)ysDBLeeQOF;M5+MQ&8rYO Vj2H^}obY~Jnx0k*S^fVx{sZOF4GjPQ diff --git a/js/react_native/android/src/main/cpp/cpp-adapter.cpp b/js/react_native/android/src/main/cpp/cpp-adapter.cpp index d75a2f9c99d8b..50434b71ec2ed 100644 --- a/js/react_native/android/src/main/cpp/cpp-adapter.cpp +++ b/js/react_native/android/src/main/cpp/cpp-adapter.cpp @@ -1,126 +1,43 @@ +#include "JsiMain.h" +#include +#include +#include +#include #include #include -#include using namespace facebook; -typedef u_int8_t byte; - -std::string jstring2string(JNIEnv* env, jstring jStr) { - if (!jStr) return ""; - - jclass stringClass = env->GetObjectClass(jStr); - jmethodID getBytes = env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B"); - const auto stringJbytes = (jbyteArray)env->CallObjectMethod(jStr, getBytes, env->NewStringUTF("UTF-8")); - - auto length = (size_t)env->GetArrayLength(stringJbytes); - jbyte* pBytes = env->GetByteArrayElements(stringJbytes, nullptr); - - std::string ret = std::string((char*)pBytes, length); - env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT); - - env->DeleteLocalRef(stringJbytes); - env->DeleteLocalRef(stringClass); - return ret; -} - -byte* getBytesFromBlob(JNIEnv* env, jobject instanceGlobal, const std::string& blobId, int offset, int size) { - if (!env) throw std::runtime_error("JNI Environment is gone!"); - - // get java class - jclass clazz = env->GetObjectClass(instanceGlobal); - // get method in java class - jmethodID getBufferJava = env->GetMethodID(clazz, "getBlobBuffer", "(Ljava/lang/String;II)[B"); - // call method - auto jstring = 
env->NewStringUTF(blobId.c_str()); - auto boxedBytes = (jbyteArray)env->CallObjectMethod(instanceGlobal, - getBufferJava, - // arguments - jstring, - offset, - size); - env->DeleteLocalRef(jstring); - - jboolean isCopy = true; - jbyte* bytes = env->GetByteArrayElements(boxedBytes, &isCopy); - env->DeleteLocalRef(boxedBytes); - return reinterpret_cast(bytes); -}; - -std::string createBlob(JNIEnv* env, jobject instanceGlobal, byte* bytes, size_t size) { - if (!env) throw std::runtime_error("JNI Environment is gone!"); - - // get java class - jclass clazz = env->GetObjectClass(instanceGlobal); - // get method in java class - jmethodID getBufferJava = env->GetMethodID(clazz, "createBlob", "([B)Ljava/lang/String;"); - // call method - auto byteArray = env->NewByteArray(size); - env->SetByteArrayRegion(byteArray, 0, size, reinterpret_cast(bytes)); - auto blobId = (jstring)env->CallObjectMethod(instanceGlobal, getBufferJava, byteArray); - env->DeleteLocalRef(byteArray); - - return jstring2string(env, blobId); +static std::shared_ptr env; + +class OnnxruntimeModule + : public jni::JavaClass { + public: + static constexpr auto kJavaDescriptor = + "Lai/onnxruntime/reactnative/OnnxruntimeModule;"; + + static void registerNatives() { + javaClassStatic()->registerNatives( + {makeNativeMethod("nativeInstall", + OnnxruntimeModule::nativeInstall), + makeNativeMethod("nativeCleanup", + OnnxruntimeModule::nativeCleanup)}); + } + + private: + static void nativeInstall(jni::alias_ref thiz, + jlong jsContextNativePointer, + jni::alias_ref + jsCallInvokerHolder) { + auto runtime = reinterpret_cast(jsContextNativePointer); + auto jsCallInvoker = jsCallInvokerHolder->cthis()->getCallInvoker(); + env = onnxruntimejsi::install(*runtime, jsCallInvoker); + } + + static void nativeCleanup(jni::alias_ref thiz) { env.reset(); } }; -extern "C" JNIEXPORT void JNICALL -Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv* env, jclass _, jlong jsiPtr, jobject instance) { - auto 
jsiRuntime = reinterpret_cast(jsiPtr); - - auto& runtime = *jsiRuntime; - - auto instanceGlobal = env->NewGlobalRef(instance); - - auto resolveArrayBuffer = jsi::Function::createFromHostFunction(runtime, - jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeResolveArrayBuffer"), - 1, - [=](jsi::Runtime& runtime, - const jsi::Value& thisValue, - const jsi::Value* arguments, - size_t count) -> jsi::Value { - if (count != 1) { - throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!"); - } - - jsi::Object data = arguments[0].asObject(runtime); - auto blobId = data.getProperty(runtime, "blobId").asString(runtime); - auto offset = data.getProperty(runtime, "offset").asNumber(); - auto size = data.getProperty(runtime, "size").asNumber(); - - auto bytes = getBytesFromBlob(env, instanceGlobal, blobId.utf8(runtime), offset, size); - - size_t totalSize = size - offset; - jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer"); - jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int)totalSize).getObject(runtime); - jsi::ArrayBuffer buf = o.getArrayBuffer(runtime); - memcpy(buf.data(runtime), reinterpret_cast(bytes), totalSize); - - return buf; - }); - runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", std::move(resolveArrayBuffer)); - - auto storeArrayBuffer = jsi::Function::createFromHostFunction(runtime, - jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeStoreArrayBuffer"), - 1, - [=](jsi::Runtime& runtime, - const jsi::Value& thisValue, - const jsi::Value* arguments, - size_t count) -> jsi::Value { - if (count != 1) { - throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) 
expects one argument (object)!"); - } - - auto arrayBuffer = arguments[0].asObject(runtime).getArrayBuffer(runtime); - auto size = arrayBuffer.size(runtime); - - std::string blobId = createBlob(env, instanceGlobal, arrayBuffer.data(runtime), size); - - jsi::Object result(runtime); - auto blobIdString = jsi::String::createFromUtf8(runtime, blobId); - result.setProperty(runtime, "blobId", blobIdString); - result.setProperty(runtime, "offset", jsi::Value(0)); - result.setProperty(runtime, "size", jsi::Value(static_cast(size))); - return result; - }); - runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", std::move(storeArrayBuffer)); +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + return jni::initialize( + vm, [] { OnnxruntimeModule::registerNatives(); }); } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java index de4c880981881..cacc382e29230 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java @@ -3,14 +3,13 @@ package ai.onnxruntime.reactnative; -import ai.onnxruntime.OrtSession.SessionOptions; import android.util.Log; class OnnxruntimeExtensions { - public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) { + static public String getLibraryPath() { Log.i("OnnxruntimeExtensions", "ORT Extensions is not enabled in the current configuration. 
If you want to enable this support, " + "please add \"onnxruntimeEnableExtensions\": \"true\" in your project root directory package.json."); - return; + return null; } } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java index 9bbf41c8f1671..d41163fdb53e9 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java @@ -3,12 +3,10 @@ package ai.onnxruntime.reactnative; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.extensions.OrtxPackage; class OnnxruntimeExtensions { - public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) throws OrtException { - sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath()); + static public String getLibraryPath() { + return OrtxPackage.getLibraryPath(); } } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java deleted file mode 100644 index 93b37df0768b4..0000000000000 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java +++ /dev/null @@ -1,70 +0,0 @@ -package ai.onnxruntime.reactnative; - -import androidx.annotation.NonNull; -import com.facebook.react.bridge.JavaScriptContextHolder; -import com.facebook.react.bridge.ReactApplicationContext; -import com.facebook.react.bridge.ReactContextBaseJavaModule; -import com.facebook.react.bridge.ReactMethod; -import com.facebook.react.module.annotations.ReactModule; -import com.facebook.react.modules.blob.BlobModule; - -@ReactModule(name = OnnxruntimeJSIHelper.NAME) -public class OnnxruntimeJSIHelper extends 
ReactContextBaseJavaModule { - public static final String NAME = "OnnxruntimeJSIHelper"; - - private static ReactApplicationContext reactContext; - protected BlobModule blobModule; - - public OnnxruntimeJSIHelper(ReactApplicationContext context) { - super(context); - reactContext = context; - } - - @Override - @NonNull - public String getName() { - return NAME; - } - - public void checkBlobModule() { - if (blobModule == null) { - blobModule = getReactApplicationContext().getNativeModule(BlobModule.class); - if (blobModule == null) { - throw new RuntimeException("BlobModule is not initialized"); - } - } - } - - @ReactMethod(isBlockingSynchronousMethod = true) - public boolean install() { - try { - System.loadLibrary("onnxruntimejsihelper"); - JavaScriptContextHolder jsContext = getReactApplicationContext().getJavaScriptContextHolder(); - nativeInstall(jsContext.get(), this); - return true; - } catch (Exception exception) { - return false; - } - } - - public byte[] getBlobBuffer(String blobId, int offset, int size) { - checkBlobModule(); - byte[] bytes = blobModule.resolve(blobId, offset, size); - blobModule.remove(blobId); - if (bytes == null) { - throw new RuntimeException("Failed to resolve Blob #" + blobId + "! 
Not found."); - } - return bytes; - } - - public String createBlob(byte[] buffer) { - checkBlobModule(); - String blobId = blobModule.store(buffer); - if (blobId == null) { - throw new RuntimeException("Failed to create Blob!"); - } - return blobId; - } - - public static native void nativeInstall(long jsiPointer, OnnxruntimeJSIHelper instance); -} diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index 496db5a6087e6..c362e6ad71bbe 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -3,65 +3,21 @@ package ai.onnxruntime.reactnative; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtLoggingLevel; -import ai.onnxruntime.OrtSession; -import ai.onnxruntime.OrtSession.Result; -import ai.onnxruntime.OrtSession.RunOptions; -import ai.onnxruntime.OrtSession.SessionOptions; -import ai.onnxruntime.providers.NNAPIFlags; -import android.net.Uri; +import java.util.Map; +import java.util.HashMap; import android.os.Build; -import android.util.Log; import androidx.annotation.NonNull; import androidx.annotation.RequiresApi; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.LifecycleEventListener; -import com.facebook.react.bridge.Promise; +import com.facebook.react.bridge.JavaScriptContextHolder; +import com.facebook.react.bridge.ReactMethod; import com.facebook.react.bridge.ReactApplicationContext; import com.facebook.react.bridge.ReactContextBaseJavaModule; -import com.facebook.react.bridge.ReactMethod; -import com.facebook.react.bridge.ReadableArray; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.bridge.ReadableType; -import 
com.facebook.react.bridge.WritableArray; -import com.facebook.react.bridge.WritableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.Reader; -import java.math.BigInteger; -import java.util.Collections; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import com.facebook.react.turbomodule.core.CallInvokerHolderImpl; @RequiresApi(api = Build.VERSION_CODES.N) -public class OnnxruntimeModule extends ReactContextBaseJavaModule implements LifecycleEventListener { +public class OnnxruntimeModule extends ReactContextBaseJavaModule { private static ReactApplicationContext reactContext; - private static OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment(); - private static Map sessionMap = new HashMap<>(); - - private static BigInteger nextSessionId = new BigInteger("0"); - private static String getNextSessionKey() { - String key = nextSessionId.toString(); - nextSessionId = nextSessionId.add(BigInteger.valueOf(1)); - return key; - } - - protected BlobModule blobModule; - public OnnxruntimeModule(ReactApplicationContext context) { super(context); reactContext = context; @@ -73,393 +29,37 @@ public String getName() { return "Onnxruntime"; } - public void checkBlobModule() { - if (blobModule == null) { - blobModule = getReactApplicationContext().getNativeModule(BlobModule.class); - if (blobModule == null) { - throw new RuntimeException("BlobModule is not initialized"); - } - } - } + native void nativeInstall(long jsiPointer, CallInvokerHolderImpl jsCallInvokerHolder); - /** - * React native binding API to load a model using given uri. 
- * - * @param uri a model file location - * @param options onnxruntime session options - * @param promise output returning back to react native js - * @note the value provided to `promise` includes a key representing the session. - * when run() is called, the key must be passed into the first parameter. - */ - @ReactMethod - public void loadModel(String uri, ReadableMap options, Promise promise) { - try { - WritableMap resultMap = loadModel(uri, options); - promise.resolve(resultMap); - } catch (Exception e) { - promise.reject("Failed to load model \"" + uri + "\": " + e.getMessage(), e); - } - } + native void nativeCleanup(); - /** - * React native binding API to load a model using blob object that data stored in BlobModule. - * - * @param data the blob object - * @param options onnxruntime session options - * @param promise output returning back to react native js - * @note the value provided to `promise` includes a key representing the session. - * when run() is called, the key must be passed into the first parameter. - */ - @ReactMethod - public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise promise) { - try { - checkBlobModule(); - String blobId = data.getString("blobId"); - byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size")); - blobModule.remove(blobId); - WritableMap resultMap = loadModel(bytes, options); - promise.resolve(resultMap); - } catch (Exception e) { - promise.reject("Failed to load model from buffer: " + e.getMessage(), e); - } - } - - /** - * React native binding API to dispose a session. 
- * - * @param key session key representing a session given at loadModel() - * @param promise output returning back to react native js - */ - @ReactMethod - public void dispose(String key, Promise promise) { - try { - dispose(key); - promise.resolve(null); - } catch (OrtException e) { - promise.reject("Failed to dispose session: " + e.getMessage(), e); - } + @Override + public void invalidate() { + super.invalidate(); + nativeCleanup(); } /** - * React native binding API to run a model using given uri. - * - * @param key session key representing a session given at loadModel() - * @param input an input tensor - * @param output an output names to be returned - * @param options onnxruntime run options - * @param promise output returning back to react native js + * Install onnxruntime JSI API */ - @ReactMethod - public void run(String key, ReadableMap input, ReadableArray output, ReadableMap options, Promise promise) { + @ReactMethod(isBlockingSynchronousMethod = true) + public boolean install() { try { - WritableMap resultMap = run(key, input, output, options); - promise.resolve(resultMap); + System.loadLibrary("onnxruntimejsi"); + JavaScriptContextHolder jsContext = getReactApplicationContext().getJavaScriptContextHolder(); + CallInvokerHolderImpl jsCallInvokerHolder = + (CallInvokerHolderImpl) getReactApplicationContext().getCatalystInstance().getJSCallInvokerHolder(); + nativeInstall(jsContext.get(), jsCallInvokerHolder); + return true; } catch (Exception e) { - promise.reject("Fail to inference: " + e.getMessage(), e); - } - } - - /** - * Load a model from raw resource directory. - * - * @param uri uri parameter from react native loadModel() - * @param options onnxruntime session options - * @return model loading information, such as key, input names, and output names - */ - public WritableMap loadModel(String uri, ReadableMap options) throws Exception { - return loadModelImpl(uri, null, options); - } - - /** - * Load a model from buffer. 
- * - * @param modelData the model data buffer - * @param options onnxruntime session options - * @return model loading information, such as key, input names, and output names - */ - public WritableMap loadModel(byte[] modelData, ReadableMap options) throws Exception { - return loadModelImpl("", modelData, options); - } - - /** - * Load model implementation method for either from model path or model data buffer. - * - * @param uri uri parameter from react native loadModel() - * @param modelData model data buffer - * @param options onnxruntime session options - * @return model loading information map, such as key, input names, and output names - */ - private WritableMap loadModelImpl(String uri, byte[] modelData, ReadableMap options) throws Exception { - OrtSession ortSession; - SessionOptions sessionOptions = parseSessionOptions(options); - - // optional call for registering custom ops when ort extensions enabled - OnnxruntimeExtensions ortExt = new OnnxruntimeExtensions(); - ortExt.registerOrtExtensionsIfEnabled(sessionOptions); - - if (modelData != null && modelData.length > 0) { - // load model via model data array - ortSession = ortEnvironment.createSession(modelData, sessionOptions); - } else if (uri.startsWith("file://") || uri.startsWith("/")) { - // load model from local - if (uri.startsWith("file://")) { - uri = uri.substring(7); - } - ortSession = ortEnvironment.createSession(uri, sessionOptions); - } else { - // load model via model path string uri - InputStream modelStream = - reactContext.getApplicationContext().getContentResolver().openInputStream(Uri.parse(uri)); - Reader reader = new BufferedReader(new InputStreamReader(modelStream)); - byte[] modelArray = new byte[modelStream.available()]; - modelStream.read(modelArray); - modelStream.close(); - ortSession = ortEnvironment.createSession(modelArray, sessionOptions); - } - - String key = getNextSessionKey(); - sessionMap.put(key, ortSession); - - WritableMap resultMap = Arguments.createMap(); - 
resultMap.putString("key", key); - WritableArray inputNames = Arguments.createArray(); - for (String inputName : ortSession.getInputNames()) { - inputNames.pushString(inputName); - } - resultMap.putArray("inputNames", inputNames); - WritableArray outputNames = Arguments.createArray(); - for (String outputName : ortSession.getOutputNames()) { - outputNames.pushString(outputName); - } - resultMap.putArray("outputNames", outputNames); - - return resultMap; - } - - /** - * Dispose a model using given key. - * - * @param key a session key representing the session given at loadModel() - */ - public void dispose(String key) throws OrtException { - OrtSession ortSession = sessionMap.get(key); - if (ortSession != null) { - ortSession.close(); - sessionMap.remove(key); - } - } - - /** - * Run a model using given uri. - * - * @param key a session key representing the session given at loadModel() - * @param input an input tensor - * @param output an output names to be returned - * @param options onnxruntime run options - * @return inference result - */ - public WritableMap run(String key, ReadableMap input, ReadableArray output, ReadableMap options) throws Exception { - OrtSession ortSession = sessionMap.get(key); - if (ortSession == null) { - throw new Exception("Model is not loaded."); - } - - RunOptions runOptions = parseRunOptions(options); - - checkBlobModule(); - - long startTime = System.currentTimeMillis(); - Map feed = new HashMap<>(); - Iterator iterator = ortSession.getInputNames().iterator(); - Result result = null; - try { - while (iterator.hasNext()) { - String inputName = iterator.next(); - - ReadableMap inputMap = input.getMap(inputName); - if (inputMap == null) { - throw new Exception("Can't find input: " + inputName); - } - - OnnxTensor onnxTensor = TensorHelper.createInputTensor(blobModule, inputMap, ortEnvironment); - feed.put(inputName, onnxTensor); - } - - Set requestedOutputs = null; - if (output.size() > 0) { - requestedOutputs = new HashSet<>(); - for 
(int i = 0; i < output.size(); ++i) { - requestedOutputs.add(output.getString(i)); - } - } - - long duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "createInputTensor: " + duration); - - startTime = System.currentTimeMillis(); - if (requestedOutputs != null) { - result = ortSession.run(feed, requestedOutputs, runOptions); - } else { - result = ortSession.run(feed, runOptions); - } - duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "inference: " + duration); - - startTime = System.currentTimeMillis(); - WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); - duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "createOutputTensor: " + duration); - - return resultMap; - - } finally { - OnnxValue.close(feed); - if (result != null) { - result.close(); - } - } - } - - private static final Map graphOptimizationLevelTable = - Stream - .of(new Object[][] { - {"disabled", SessionOptions.OptLevel.NO_OPT}, - {"basic", SessionOptions.OptLevel.BASIC_OPT}, - {"extended", SessionOptions.OptLevel.EXTENDED_OPT}, - // {"layout", SessionOptions.OptLevel.LAYOUT_OPT}, - {"all", SessionOptions.OptLevel.ALL_OPT}, - }) - .collect(Collectors.toMap(p -> (String)p[0], p -> (SessionOptions.OptLevel)p[1])); - - private static final Map executionModeTable = - Stream - .of(new Object[][] {{"sequential", SessionOptions.ExecutionMode.SEQUENTIAL}, - {"parallel", SessionOptions.ExecutionMode.PARALLEL}}) - .collect(Collectors.toMap(p -> (String)p[0], p -> (SessionOptions.ExecutionMode)p[1])); - - private SessionOptions parseSessionOptions(ReadableMap options) throws OrtException { - SessionOptions sessionOptions = new SessionOptions(); - - if (options.hasKey("intraOpNumThreads")) { - int intraOpNumThreads = options.getInt("intraOpNumThreads"); - if (intraOpNumThreads > 0 && intraOpNumThreads < Integer.MAX_VALUE) { - sessionOptions.setIntraOpNumThreads(intraOpNumThreads); - } - } - - if 
(options.hasKey("interOpNumThreads")) { - int interOpNumThreads = options.getInt("interOpNumThreads"); - if (interOpNumThreads > 0 && interOpNumThreads < Integer.MAX_VALUE) { - sessionOptions.setInterOpNumThreads(interOpNumThreads); - } + return false; } - - if (options.hasKey("graphOptimizationLevel")) { - String graphOptimizationLevel = options.getString("graphOptimizationLevel"); - if (graphOptimizationLevelTable.containsKey(graphOptimizationLevel)) { - sessionOptions.setOptimizationLevel(graphOptimizationLevelTable.get(graphOptimizationLevel)); - } - } - - if (options.hasKey("enableCpuMemArena")) { - boolean enableCpuMemArena = options.getBoolean("enableCpuMemArena"); - sessionOptions.setCPUArenaAllocator(enableCpuMemArena); - } - - if (options.hasKey("enableMemPattern")) { - boolean enableMemPattern = options.getBoolean("enableMemPattern"); - sessionOptions.setMemoryPatternOptimization(enableMemPattern); - } - - if (options.hasKey("executionMode")) { - String executionMode = options.getString("executionMode"); - if (executionModeTable.containsKey(executionMode)) { - sessionOptions.setExecutionMode(executionModeTable.get(executionMode)); - } - } - - if (options.hasKey("executionProviders")) { - ReadableArray executionProviders = options.getArray("executionProviders"); - for (int i = 0; i < executionProviders.size(); ++i) { - String epName = null; - ReadableMap epOptions = null; - if (executionProviders.getType(i) == ReadableType.String) { - epName = executionProviders.getString(i); - } else { - epOptions = executionProviders.getMap(i); - epName = epOptions.getString("name"); - } - if (epName.equals("nnapi")) { - EnumSet flags = EnumSet.noneOf(NNAPIFlags.class); - if (epOptions != null) { - if (epOptions.hasKey("useFP16") && epOptions.getBoolean("useFP16")) { - flags.add(NNAPIFlags.USE_FP16); - } - if (epOptions.hasKey("useNCHW") && epOptions.getBoolean("useNCHW")) { - flags.add(NNAPIFlags.USE_NCHW); - } - if (epOptions.hasKey("cpuDisabled") && 
epOptions.getBoolean("cpuDisabled")) { - flags.add(NNAPIFlags.CPU_DISABLED); - } - if (epOptions.hasKey("cpuOnly") && epOptions.getBoolean("cpuOnly")) { - flags.add(NNAPIFlags.CPU_ONLY); - } - } - sessionOptions.addNnapi(flags); - } else if (epName.equals("xnnpack")) { - sessionOptions.addXnnpack(Collections.emptyMap()); - } else if (epName.equals("cpu")) { - continue; - } else { - throw new OrtException("Unsupported execution provider: " + epName); - } - } - } - - if (options.hasKey("logId")) { - String logId = options.getString("logId"); - sessionOptions.setLoggerId(logId); - } - - if (options.hasKey("logSeverityLevel")) { - int logSeverityLevel = options.getInt("logSeverityLevel"); - sessionOptions.setSessionLogLevel(OrtLoggingLevel.mapFromInt(logSeverityLevel)); - } - - return sessionOptions; } - private RunOptions parseRunOptions(ReadableMap options) throws OrtException { - RunOptions runOptions = new RunOptions(); - - if (options.hasKey("logSeverityLevel")) { - int logSeverityLevel = options.getInt("logSeverityLevel"); - runOptions.setLogLevel(OrtLoggingLevel.mapFromInt(logSeverityLevel)); - } - - if (options.hasKey("tag")) { - String tag = options.getString("tag"); - runOptions.setRunTag(tag); - } - - return runOptions; - } - - @Override - public void onHostResume() {} - @Override - public void onHostPause() {} - - @Override - public void onHostDestroy() { - for (String key : sessionMap.keySet()) { - try { - dispose(key); - } catch (Exception e) { - Log.e("onHostDestroy", "Failed to dispose session: " + key, e); - } - } - sessionMap.clear(); + public Map getConstants() { + final Map constants = new HashMap(); + constants.put("ORT_EXTENSIONS_PATH", OnnxruntimeExtensions.getLibraryPath()); + return constants; } } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java index bb4386a0953f3..9171641e6e68a 100644 --- 
a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java @@ -22,7 +22,6 @@ public class OnnxruntimePackage implements ReactPackage { public List createNativeModules(@NonNull ReactApplicationContext reactContext) { List modules = new ArrayList<>(); modules.add(new OnnxruntimeModule(reactContext)); - modules.add(new OnnxruntimeJSIHelper(reactContext)); return modules; } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java deleted file mode 100644 index 63cddace36640..0000000000000 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package ai.onnxruntime.reactnative; - -import ai.onnxruntime.OnnxJavaType; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtSession; -import ai.onnxruntime.OrtUtil; -import ai.onnxruntime.TensorInfo; -import android.util.Base64; -import com.facebook.react.bridge.Arguments; -import com.facebook.react.bridge.ReadableArray; -import com.facebook.react.bridge.ReadableMap; -import com.facebook.react.bridge.WritableArray; -import com.facebook.react.bridge.WritableMap; -import com.facebook.react.modules.blob.BlobModule; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import java.nio.ShortBuffer; -import java.util.Iterator; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -public class TensorHelper { - /** - * Supported tensor data type - */ - 
public static final String JsTensorTypeBool = "bool"; - public static final String JsTensorTypeByte = "int8"; - public static final String JsTensorTypeUnsignedByte = "uint8"; - public static final String JsTensorTypeShort = "int16"; - public static final String JsTensorTypeInt = "int32"; - public static final String JsTensorTypeLong = "int64"; - public static final String JsTensorTypeFloat = "float32"; - public static final String JsTensorTypeDouble = "float64"; - public static final String JsTensorTypeString = "string"; - - /** - * It creates an input tensor from a map passed by react native js. - * 'data' is blob object and the buffer is stored in BlobModule. It first resolve it and creates a tensor. - */ - public static OnnxTensor createInputTensor(BlobModule blobModule, ReadableMap inputTensor, - OrtEnvironment ortEnvironment) throws Exception { - // shape - ReadableArray dimsArray = inputTensor.getArray("dims"); - long[] dims = new long[dimsArray.size()]; - for (int i = 0; i < dimsArray.size(); ++i) { - dims[i] = dimsArray.getInt(i); - } - - // type - TensorInfo.OnnxTensorType tensorType = getOnnxTensorType(inputTensor.getString("type")); - - // data - OnnxTensor onnxTensor = null; - if (tensorType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - ReadableArray values = inputTensor.getArray("data"); - String[] buffer = new String[values.size()]; - for (int i = 0; i < values.size(); ++i) { - buffer[i] = values.getString(i); - } - onnxTensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - } else { - ReadableMap data = inputTensor.getMap("data"); - String blobId = data.getString("blobId"); - byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size")); - blobModule.remove(blobId); - ByteBuffer values = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()); - onnxTensor = createInputTensor(tensorType, dims, values, ortEnvironment); - } - - return onnxTensor; - } - - /** - * It creates an output map from an 
output tensor. - * a data array is store in BlobModule. - */ - public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception { - WritableMap outputTensorMap = Arguments.createMap(); - - Iterator> iterator = result.iterator(); - while (iterator.hasNext()) { - Map.Entry entry = iterator.next(); - String outputName = entry.getKey(); - OnnxValue onnxValue = (OnnxValue)entry.getValue(); - if (onnxValue.getType() != OnnxValue.OnnxValueType.ONNX_TYPE_TENSOR) { - throw new Exception("Not supported type: " + onnxValue.getType().toString()); - } - - OnnxTensor onnxTensor = (OnnxTensor)onnxValue; - WritableMap outputTensor = Arguments.createMap(); - - // dims - WritableArray outputDims = Arguments.createArray(); - long[] dims = onnxTensor.getInfo().getShape(); - for (long dim : dims) { - outputDims.pushInt((int)dim); - } - outputTensor.putArray("dims", outputDims); - - // type - outputTensor.putString("type", getJsTensorType(onnxTensor.getInfo().onnxType)); - - // data - if (onnxTensor.getInfo().onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - String[] buffer = (String[])onnxTensor.getValue(); - WritableArray dataArray = Arguments.createArray(); - for (String value : buffer) { - dataArray.pushString(value); - } - outputTensor.putArray("data", dataArray); - } else { - // Store in BlobModule then create a blob object as data - byte[] bufferArray = createOutputTensor(onnxTensor); - WritableMap data = Arguments.createMap(); - data.putString("blobId", blobModule.store(bufferArray)); - data.putInt("offset", 0); - data.putInt("size", bufferArray.length); - outputTensor.putMap("data", data); - } - - outputTensorMap.putMap(outputName, outputTensor); - } - - return outputTensorMap; - } - - private static OnnxTensor createInputTensor(TensorInfo.OnnxTensorType tensorType, long[] dims, ByteBuffer values, - OrtEnvironment ortEnvironment) throws Exception { - OnnxTensor tensor = null; - switch (tensorType) { - 
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - FloatBuffer buffer = values.asFloatBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - ByteBuffer buffer = values; - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - ShortBuffer buffer = values.asShortBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - IntBuffer buffer = values.asIntBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - LongBuffer buffer = values.asLongBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - DoubleBuffer buffer = values.asDoubleBuffer(); - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - ByteBuffer buffer = values; - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - ByteBuffer buffer = values; - tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.BOOL); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - default: - throw new IllegalStateException("Unexpected value: " + tensorType.toString()); - } - - return tensor; - } - - private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception { - TensorInfo tensorInfo = onnxTensor.getInfo(); - ByteBuffer buffer = null; - - int capacity = (int)OrtUtil.elementCount(onnxTensor.getInfo().getShape()); - - switch (tensorInfo.onnxType) { - case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - buffer = ByteBuffer.allocate(capacity * 4).order(ByteOrder.nativeOrder()); - buffer.asFloatBuffer().put(onnxTensor.getFloatBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - buffer = ByteBuffer.allocate(capacity).order(ByteOrder.nativeOrder()); - buffer.put(onnxTensor.getByteBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - buffer = ByteBuffer.allocate(capacity * 2).order(ByteOrder.nativeOrder()); - buffer.asShortBuffer().put(onnxTensor.getShortBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - buffer = ByteBuffer.allocate(capacity * 4).order(ByteOrder.nativeOrder()); - buffer.asIntBuffer().put(onnxTensor.getIntBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - buffer = ByteBuffer.allocate(capacity * 8).order(ByteOrder.nativeOrder()); - buffer.asLongBuffer().put(onnxTensor.getLongBuffer()); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - buffer = ByteBuffer.allocate(capacity * 8).order(ByteOrder.nativeOrder()); - buffer.asDoubleBuffer().put(onnxTensor.getDoubleBuffer()); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - buffer = ByteBuffer.allocate(capacity).order(ByteOrder.nativeOrder()); - buffer.put(onnxTensor.getByteBuffer()); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - default: - throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString()); - } - - return buffer.array(); - } - - private static final Map JsTensorTypeToOnnxTensorTypeMap = - Stream - .of(new Object[][] { - {JsTensorTypeFloat, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, - {JsTensorTypeByte, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, - {JsTensorTypeUnsignedByte, 
TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, - {JsTensorTypeShort, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, - {JsTensorTypeInt, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, - {JsTensorTypeLong, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, - {JsTensorTypeString, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, - {JsTensorTypeBool, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, - {JsTensorTypeDouble, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, - }) - .collect(Collectors.toMap(p -> (String)p[0], p -> (TensorInfo.OnnxTensorType)p[1])); - - private static TensorInfo.OnnxTensorType getOnnxTensorType(String type) { - if (JsTensorTypeToOnnxTensorTypeMap.containsKey(type)) { - return JsTensorTypeToOnnxTensorTypeMap.get(type); - } else { - return TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - } - } - - private static final Map OnnxTensorTypeToJsTensorTypeMap = - Stream - .of(new Object[][] { - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, JsTensorTypeFloat}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, JsTensorTypeByte}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, JsTensorTypeUnsignedByte}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, JsTensorTypeShort}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, JsTensorTypeInt}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, JsTensorTypeLong}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, JsTensorTypeString}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, JsTensorTypeBool}, - {TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, JsTensorTypeDouble}, - }) - .collect(Collectors.toMap(p -> (TensorInfo.OnnxTensorType)p[0], p -> (String)p[1])); - - private static String getJsTensorType(TensorInfo.OnnxTensorType type) { - 
if (OnnxTensorTypeToJsTensorTypeMap.containsKey(type)) { - return OnnxTensorTypeToJsTensorTypeMap.get(type); - } else { - return "undefined"; - } - } -} diff --git a/js/react_native/cpp/AsyncWorker.h b/js/react_native/cpp/AsyncWorker.h new file mode 100644 index 0000000000000..ceca5c0ac203e --- /dev/null +++ b/js/react_native/cpp/AsyncWorker.h @@ -0,0 +1,131 @@ +#pragma once + +#include "Env.h" +#include +#include +#include +#include +#include +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +/** + * @brief AsyncWorker is a helper class to run a function asynchronously and + * return a promise. + * + * @param rt The runtime to use. + * @param env The environment to use. + */ +class AsyncWorker : public HostObject, public std::enable_shared_from_this { + public: + AsyncWorker(Runtime& rt, std::shared_ptr env) : rt_(rt), env_(env), cancel_(false) {} + + ~AsyncWorker() { + if (worker_.joinable()) { + if (worker_.get_id() != std::this_thread::get_id()) { + cancel_ = true; + onAbort(); + worker_.join(); + } else { + worker_.detach(); + } + } + } + + /** + * @brief Make sure the value won't be garbage collected during the async + * operation. + * + * @param rt The runtime to use. + * @param value The value to keep. + */ + void keepValue(Runtime& rt, const Value& value) { + keptValues_.push_back(std::make_shared(rt, value)); + } + + /** + * @brief Create a promise to be used in the async operation. + * + * @param rt The runtime to use. + * @return The promise. 
+ */ + Value toPromise(Runtime& rt) { + auto promiseCtor = rt.global().getPropertyAsFunction(rt, "Promise"); + + auto promise = promiseCtor.callAsConstructor( + rt, Function::createFromHostFunction( + rt, PropNameID::forAscii(rt, "executor"), 2, + [this](Runtime& rt, const Value& thisVal, const Value* args, + size_t count) -> Value { + resolveFunc_ = std::make_shared(rt, args[0]); + rejectFunc_ = std::make_shared(rt, args[1]); + cancel_ = false; + worker_ = std::thread([this]() { + if (cancel_) return; + try { + execute(); + dispatchResolve(); + } catch (const std::exception& e) { + dispatchReject(e.what()); + } + }); + return Value::undefined(); + })); + promise.asObject(rt).setProperty(rt, "__nativeWorker", Object::createFromHostObject(rt, shared_from_this())); + return promise; + } + + protected: + virtual void execute() = 0; + + virtual Value onResolve(Runtime& rt) = 0; + virtual Value onReject(Runtime& rt, const std::string& err) { + return String::createFromUtf8(rt, err); + } + + virtual void onAbort() {} + + private: + void dispatchResolve() { + if (cancel_) return; + auto self = shared_from_this(); + env_->runOnJsThread([self]() { + auto resVal = self->onResolve(self->rt_); + self->resolveFunc_->asObject(self->rt_) + .asFunction(self->rt_) + .call(self->rt_, resVal); + self->clearKeeps(); + }); + } + + void dispatchReject(const std::string& err) { + if (cancel_) return; + auto self = shared_from_this(); + env_->runOnJsThread([self, err]() { + auto resVal = self->onReject(self->rt_, err); + self->rejectFunc_->asObject(self->rt_) + .asFunction(self->rt_) + .call(self->rt_, resVal); + self->clearKeeps(); + }); + } + + void clearKeeps() { + keptValues_.clear(); + resolveFunc_.reset(); + rejectFunc_.reset(); + } + + Runtime& rt_; + std::shared_ptr env_; + std::atomic cancel_; + std::vector> keptValues_; + std::shared_ptr resolveFunc_; + std::shared_ptr rejectFunc_; + std::thread worker_; +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/Env.h 
b/js/react_native/cpp/Env.h new file mode 100644 index 0000000000000..9e7b7651a971f --- /dev/null +++ b/js/react_native/cpp/Env.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "onnxruntime_cxx_api.h" +#include + +namespace onnxruntimejsi { + +class Env : public std::enable_shared_from_this { + public: + Env(std::shared_ptr jsInvoker) + : jsInvoker_(jsInvoker) {} + + ~Env() {} + + inline void initOrtEnv(OrtLoggingLevel logLevel, const char* logid) { + if (ortEnv_) { + return; + } + ortEnv_ = std::make_shared(logLevel, logid); + } + + inline void setTensorConstructor( + std::shared_ptr tensorConstructor) { + tensorConstructor_ = tensorConstructor; + } + + inline facebook::jsi::Value + getTensorConstructor(facebook::jsi::Runtime& runtime) const { + return tensorConstructor_->lock(runtime); + } + + inline Ort::Env& getOrtEnv() const { return *ortEnv_; } + + inline void runOnJsThread(std::function&& func) { + if (!jsInvoker_) return; + jsInvoker_->invokeAsync(std::move(func)); + } + + private: + std::shared_ptr jsInvoker_; + std::shared_ptr tensorConstructor_; + std::shared_ptr ortEnv_; +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/InferenceSessionHostObject.cpp b/js/react_native/cpp/InferenceSessionHostObject.cpp new file mode 100644 index 0000000000000..c8efd49e6d669 --- /dev/null +++ b/js/react_native/cpp/InferenceSessionHostObject.cpp @@ -0,0 +1,312 @@ +#include "InferenceSessionHostObject.h" +#include "AsyncWorker.h" +#include "JsiUtils.h" +#include "SessionUtils.h" +#include "TensorUtils.h" +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +class InferenceSessionHostObject::LoadModelAsyncWorker : public AsyncWorker { + public: + LoadModelAsyncWorker(Runtime& runtime, const Value* arguments, size_t count, + std::shared_ptr session) + : AsyncWorker(runtime, session->env_), session_(session) { + if (count < 1) + throw JSError(runtime, "loadModel requires at least 1 
argument"); + if (arguments[0].isString()) { + modelPath_ = arguments[0].asString(runtime).utf8(runtime); + if (modelPath_.find("file:/") == 0) { + modelPath_ = modelPath_.substr(5); + if (modelPath_.find("//") == 0) { + modelPath_ = modelPath_.substr(2); + } + } + } else if (arguments[0].isObject() && + arguments[0].asObject(runtime).isArrayBuffer(runtime)) { + auto arrayBufferObj = arguments[0].asObject(runtime); + auto arrayBuffer = arrayBufferObj.getArrayBuffer(runtime); + modelData_ = arrayBuffer.data(runtime); + modelDataLength_ = arrayBuffer.size(runtime); + } else { + throw JSError(runtime, "Model path or buffer is required"); + } + keepValue(runtime, arguments[0]); + if (count > 1) { + parseSessionOptions(runtime, arguments[1], sessionOptions_); + } + } + + protected: + void execute() { + if (modelPath_.empty()) { + session_->session_ = std::make_shared( + session_->env_->getOrtEnv(), modelData_, modelDataLength_, + sessionOptions_); + } else { + session_->session_ = std::make_shared( + session_->env_->getOrtEnv(), modelPath_.c_str(), sessionOptions_); + } + } + + Value onResolve(Runtime& rt) { return Value::undefined(); } + + private: + std::string error_; + std::string modelPath_; + void* modelData_; + size_t modelDataLength_; + std::shared_ptr session_; + Ort::SessionOptions sessionOptions_; + std::shared_ptr weakResolve_; + std::shared_ptr weakReject_; + std::thread thread_; +}; + +DEFINE_METHOD(InferenceSessionHostObject::loadModel) { + auto self = shared_from_this(); + auto worker = + std::make_shared(runtime, arguments, count, self); + return worker->toPromise(runtime); +} + +class InferenceSessionHostObject::RunAsyncWorker : public AsyncWorker { + public: + RunAsyncWorker(Runtime& runtime, const Value* arguments, size_t count, + std::shared_ptr session) + : AsyncWorker(runtime, session->env_), + env_(session->env_), + session_(session->session_), + memoryInfo_(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault)) { + if (count < 1) + 
throw JSError(runtime, "run requires at least 1 argument"); + if (count > 2 && !arguments[2].isUndefined()) { + parseRunOptions(runtime, arguments[2], runOptions_); + } + forEach(runtime, arguments[0].asObject(runtime), + [&](const std::string& key, const Value& value, size_t index) { + inputNames_.push_back(key); + inputValues_.push_back(TensorUtils::createOrtValueFromJSTensor( + runtime, value.asObject(runtime), memoryInfo_)); + keepValue(runtime, value); + }); + forEach(runtime, arguments[1].asObject(runtime), + [&](const std::string& key, const Value& value, size_t index) { + outputNames_.push_back(key); + if (value.isObject() && + TensorUtils::isTensor(runtime, value.asObject(runtime))) { + outputValues_.push_back(TensorUtils::createOrtValueFromJSTensor( + runtime, value.asObject(runtime), memoryInfo_)); + jsOutputValues_.push_back(std::make_shared( + runtime, value.asObject(runtime))); + keepValue(runtime, value); + } else { + outputValues_.push_back(Ort::Value()); + jsOutputValues_.push_back(nullptr); + } + }); + } + + protected: + void execute() { + auto inputNames = std::vector(inputNames_.size()); + std::transform(inputNames_.begin(), inputNames_.end(), inputNames.begin(), + [](const std::string& name) { return name.c_str(); }); + auto outputNames = std::vector(outputNames_.size()); + std::transform(outputNames_.begin(), outputNames_.end(), + outputNames.begin(), + [](const std::string& name) { return name.c_str(); }); + auto session = session_.lock(); + if (!session) { + throw std::runtime_error("Session is released"); + } + session->Run(runOptions_, inputNames.data(), inputValues_.data(), + inputValues_.size(), outputNames.data(), + outputValues_.data(), outputValues_.size()); + } + + Value onResolve(Runtime& rt) { + auto resultObject = Object(rt); + auto tensorConstructor = + env_->getTensorConstructor(rt).asObject(rt); + for (size_t i = 0; i < outputValues_.size(); ++i) { + if (jsOutputValues_[i] != nullptr && outputValues_[i].IsTensor()) { + 
resultObject.setProperty(rt, outputNames_[i].c_str(), + jsOutputValues_[i]->lock(rt)); + } else { + auto tensorObj = TensorUtils::createJSTensorFromOrtValue( + rt, outputValues_[i], tensorConstructor); + resultObject.setProperty(rt, outputNames_[i].c_str(), + Value(rt, tensorObj)); + } + } + return Value(rt, resultObject); + } + + void onAbort() { + runOptions_.SetTerminate(); + } + + private: + std::shared_ptr env_; + std::weak_ptr session_; + Ort::MemoryInfo memoryInfo_; + Ort::RunOptions runOptions_; + std::vector inputNames_; + std::vector inputValues_; + std::vector outputNames_; + std::vector outputValues_; + std::vector> jsOutputValues_; +}; + +DEFINE_METHOD(InferenceSessionHostObject::run) { + auto self = shared_from_this(); + auto worker = + std::make_shared(runtime, arguments, count, self); + return worker->toPromise(runtime); +} + +DEFINE_METHOD(InferenceSessionHostObject::dispose) { + session_.reset(); + return Value::undefined(); +} + +DEFINE_METHOD(InferenceSessionHostObject::endProfiling) { + try { + Ort::AllocatorWithDefaultOptions allocator; + auto filename = session_->EndProfilingAllocated(allocator); + return String::createFromUtf8(runtime, std::string(filename.get())); + } catch (const std::exception& e) { + throw JSError(runtime, std::string(e.what())); + } +} + +DEFINE_GETTER(InferenceSessionHostObject::inputMetadata) { + if (!session_) { + return Array(runtime, 0); + } + try { + Ort::AllocatorWithDefaultOptions allocator; + size_t numInputs = session_->GetInputCount(); + auto array = Array(runtime, numInputs); + + for (size_t i = 0; i < numInputs; i++) { + auto item = Object(runtime); + auto inputName = session_->GetInputNameAllocated(i, allocator); + item.setProperty( + runtime, "name", + String::createFromUtf8(runtime, std::string(inputName.get()))); + + try { + auto typeInfo = session_->GetInputTypeInfo(i); + auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo(); + + // Get data type + auto dataType = tensorInfo.GetElementType(); + 
item.setProperty(runtime, "type", static_cast(dataType)); + + // Get shape + auto shape = tensorInfo.GetShape(); + auto shapeArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); j++) { + shapeArray.setValueAtIndex(runtime, j, + Value(static_cast(shape[j]))); + } + item.setProperty(runtime, "shape", shapeArray); + + item.setProperty(runtime, "isTensor", Value(true)); + + // symbolicDimensions + auto symbolicDimensions = tensorInfo.GetSymbolicDimensions(); + auto symbolicDimensionsArray = + Array(runtime, symbolicDimensions.size()); + for (size_t j = 0; j < symbolicDimensions.size(); j++) { + symbolicDimensionsArray.setValueAtIndex( + runtime, j, + String::createFromUtf8(runtime, symbolicDimensions[j])); + } + item.setProperty(runtime, "symbolicDimensions", + symbolicDimensionsArray); + } catch (const std::exception&) { + // Fallback for unknown types + item.setProperty(runtime, "type", + String::createFromUtf8(runtime, "unknown")); + item.setProperty(runtime, "shape", Array(runtime, 0)); + item.setProperty(runtime, "isTensor", Value(false)); + } + + array.setValueAtIndex(runtime, i, Value(runtime, item)); + } + + return Value(runtime, array); + } catch (const Ort::Exception& e) { + throw JSError(runtime, std::string(e.what())); + } +} + +DEFINE_GETTER(InferenceSessionHostObject::outputMetadata) { + if (!session_) { + return Array(runtime, 0); + } + try { + Ort::AllocatorWithDefaultOptions allocator; + size_t numOutputs = session_->GetOutputCount(); + auto array = Array(runtime, numOutputs); + + for (size_t i = 0; i < numOutputs; i++) { + auto item = Object(runtime); + auto outputName = session_->GetOutputNameAllocated(i, allocator); + item.setProperty( + runtime, "name", + String::createFromUtf8(runtime, std::string(outputName.get()))); + + try { + auto typeInfo = session_->GetOutputTypeInfo(i); + auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo(); + + // Get data type + auto dataType = tensorInfo.GetElementType(); + 
item.setProperty(runtime, "type", static_cast(dataType)); + + // Get shape + auto shape = tensorInfo.GetShape(); + auto shapeArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); j++) { + shapeArray.setValueAtIndex(runtime, j, + Value(static_cast(shape[j]))); + } + item.setProperty(runtime, "shape", shapeArray); + + item.setProperty(runtime, "isTensor", Value(true)); + + // symbolicDimensions + auto symbolicDimensions = tensorInfo.GetSymbolicDimensions(); + auto symbolicDimensionsArray = + Array(runtime, symbolicDimensions.size()); + for (size_t j = 0; j < symbolicDimensions.size(); j++) { + symbolicDimensionsArray.setValueAtIndex( + runtime, j, + String::createFromUtf8(runtime, symbolicDimensions[j])); + } + item.setProperty(runtime, "symbolicDimensions", + symbolicDimensionsArray); + } catch (const std::exception&) { + // Fallback for unknown types + item.setProperty(runtime, "type", + String::createFromUtf8(runtime, "unknown")); + item.setProperty(runtime, "shape", Array(runtime, 0)); + item.setProperty(runtime, "isTensor", Value(false)); + } + + array.setValueAtIndex(runtime, i, Value(runtime, item)); + } + + return Value(runtime, array); + } catch (const Ort::Exception& e) { + throw JSError(runtime, std::string(e.what())); + } +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/InferenceSessionHostObject.h b/js/react_native/cpp/InferenceSessionHostObject.h new file mode 100644 index 0000000000000..f13a8d46d4048 --- /dev/null +++ b/js/react_native/cpp/InferenceSessionHostObject.h @@ -0,0 +1,55 @@ +#pragma once + +#include "Env.h" +#include "JsiHelper.h" +#include +#include +#include "onnxruntime_cxx_api.h" +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +class InferenceSessionHostObject + : public HostObjectHelper, + public std::enable_shared_from_this { + public: + InferenceSessionHostObject(std::shared_ptr env) : HostObjectHelper({ + METHOD_INFO(InferenceSessionHostObject, loadModel, 2), + 
METHOD_INFO(InferenceSessionHostObject, run, 2), + METHOD_INFO(InferenceSessionHostObject, dispose, 0), + METHOD_INFO(InferenceSessionHostObject, endProfiling, 0), + }, + { + GETTER_INFO(InferenceSessionHostObject, inputMetadata), + GETTER_INFO(InferenceSessionHostObject, outputMetadata), + }), + env_(env) {} + + static inline facebook::jsi::Value + constructor(std::shared_ptr env, facebook::jsi::Runtime& runtime, + const facebook::jsi::Value& thisValue, + const facebook::jsi::Value* arguments, size_t count) { + return facebook::jsi::Object::createFromHostObject( + runtime, std::make_shared(env)); + } + + protected: + class LoadModelAsyncWorker; + class RunAsyncWorker; + + private: + std::shared_ptr env_; + std::shared_ptr session_; + + DEFINE_METHOD(loadModel); + DEFINE_METHOD(run); + DEFINE_METHOD(dispose); + DEFINE_METHOD(endProfiling); + + DEFINE_GETTER(inputMetadata); + DEFINE_GETTER(outputMetadata); +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/JsiHelper.h b/js/react_native/cpp/JsiHelper.h new file mode 100644 index 0000000000000..953429e2c26d5 --- /dev/null +++ b/js/react_native/cpp/JsiHelper.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include +#include +#include + +#define BIND_METHOD(method) \ + std::bind(&method, std::placeholders::_1, std::placeholders::_2, \ + std::placeholders::_3, std::placeholders::_4) + +#define BIND_GETTER(method) std::bind(&method, std::placeholders::_1) + +#define BIND_SETTER(method) \ + std::bind(&method, std::placeholders::_1, std::placeholders::_2) + +#define BIND_THIS_METHOD(cls, name) \ + std::bind(&cls::name##_method, this, std::placeholders::_1, \ + std::placeholders::_2, std::placeholders::_3, \ + std::placeholders::_4) + +#define BIND_THIS_GETTER(cls, name) \ + std::bind(&cls::name##_get, this, std::placeholders::_1) + +#define BIND_THIS_SETTER(cls, name) \ + std::bind(&cls::name##_set, this, std::placeholders::_1, \ + std::placeholders::_2) + +#define METHOD_INFO(cls, name, count) \ 
+ { \ + #name, { BIND_THIS_METHOD(cls, name), count } \ + } + +#define GETTER_INFO(cls, name) \ + {#name, BIND_THIS_GETTER(cls, name)} + +#define DEFINE_METHOD(name) \ + Value name##_method(Runtime& runtime, const Value& thisValue, \ + const Value* arguments, size_t count) + +#define DEFINE_GETTER(name) Value name##_get(Runtime& runtime) + +#define DEFINE_SETTER(name) \ + void name##_set(Runtime& runtime, const Value& value) + +typedef std::function + JsiMethod; +typedef std::function + JsiGetter; +typedef std::function + JsiSetter; + +struct JsiMethodInfo { + JsiMethod method; + size_t count; +}; + +typedef std::unordered_map JsiMethodMap; +typedef std::unordered_map JsiGetterMap; +typedef std::unordered_map JsiSetterMap; + +class HostObjectHelper : public facebook::jsi::HostObject { + public: + HostObjectHelper( + JsiMethodMap methods = {}, + JsiGetterMap getters = {}, + JsiSetterMap setters = {}) + : methods_(methods), + getters_(getters), + setters_(setters) {} + + std::vector + getPropertyNames(facebook::jsi::Runtime& runtime) override { + std::vector names; + for (auto& [name, _] : methods_) { + names.push_back(facebook::jsi::PropNameID::forUtf8(runtime, name)); + } + for (auto& [name, _] : getters_) { + names.push_back(facebook::jsi::PropNameID::forUtf8(runtime, name)); + } + return names; + } + + facebook::jsi::Value get(facebook::jsi::Runtime& runtime, + const facebook::jsi::PropNameID& name) override { + auto method = methods_.find(name.utf8(runtime)); + if (method != methods_.end()) { + return facebook::jsi::Function::createFromHostFunction(runtime, name, method->second.count, + method->second.method); + } + + auto getter = getters_.find(name.utf8(runtime)); + if (getter != getters_.end()) { + return getter->second(runtime); + } + + return facebook::jsi::Value::undefined(); + } + + void set(facebook::jsi::Runtime& runtime, const facebook::jsi::PropNameID& name, + const facebook::jsi::Value& value) override { + auto setter = 
setters_.find(name.utf8(runtime)); + if (setter != setters_.end()) { + setter->second(runtime, value); + } + } + + private: + JsiMethodMap methods_; + JsiGetterMap getters_; + JsiSetterMap setters_; +}; diff --git a/js/react_native/cpp/JsiMain.cpp b/js/react_native/cpp/JsiMain.cpp new file mode 100644 index 0000000000000..26e8842b32793 --- /dev/null +++ b/js/react_native/cpp/JsiMain.cpp @@ -0,0 +1,98 @@ +#include "JsiMain.h" +#include "InferenceSessionHostObject.h" +#include "JsiHelper.h" +#include "SessionUtils.h" +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +std::shared_ptr +install(Runtime& runtime, + std::shared_ptr jsInvoker) { + auto env = std::make_shared(jsInvoker); + try { + auto ortApi = Object(runtime); + + auto initOrtOnceMethod = Function::createFromHostFunction( + runtime, PropNameID::forAscii(runtime, "initOrtOnce"), 2, + [env](Runtime& runtime, const Value& thisValue, const Value* arguments, + size_t count) -> Value { + try { + OrtLoggingLevel logLevel = ORT_LOGGING_LEVEL_WARNING; + if (count > 0 && arguments[0].isNumber()) { + int level = static_cast(arguments[0].asNumber()); + switch (level) { + case 0: + logLevel = ORT_LOGGING_LEVEL_VERBOSE; + break; + case 1: + logLevel = ORT_LOGGING_LEVEL_INFO; + break; + case 2: + logLevel = ORT_LOGGING_LEVEL_WARNING; + break; + case 3: + logLevel = ORT_LOGGING_LEVEL_ERROR; + break; + case 4: + logLevel = ORT_LOGGING_LEVEL_FATAL; + break; + default: + logLevel = ORT_LOGGING_LEVEL_WARNING; + break; + } + } + env->setTensorConstructor(std::make_shared( + runtime, arguments[1].asObject(runtime))); + env->initOrtEnv(logLevel, "onnxruntime-react-native-jsi"); + return Value::undefined(); + } catch (const std::exception& e) { + throw JSError(runtime, "Failed to initialize ONNX Runtime: " + + std::string(e.what())); + } + }); + + ortApi.setProperty(runtime, "initOrtOnce", initOrtOnceMethod); + + auto createInferenceSessionMethod = Function::createFromHostFunction( + runtime, 
PropNameID::forAscii(runtime, "createInferenceSession"), 0, + std::bind(InferenceSessionHostObject::constructor, env, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4)); + ortApi.setProperty(runtime, "createInferenceSession", + createInferenceSessionMethod); + + auto listSupportedBackendsMethod = Function::createFromHostFunction( + runtime, PropNameID::forAscii(runtime, "listSupportedBackends"), 0, + [](Runtime& runtime, const Value& thisValue, const Value* arguments, + size_t count) -> Value { + auto backends = Array(runtime, supportedBackends.size()); + for (size_t i = 0; i < supportedBackends.size(); i++) { + auto backend = Object(runtime); + backend.setProperty( + runtime, "name", + String::createFromUtf8(runtime, supportedBackends[i])); + backends.setValueAtIndex(runtime, i, Value(runtime, backend)); + } + return Value(runtime, backends); + }); + + ortApi.setProperty(runtime, "listSupportedBackends", + listSupportedBackendsMethod); + + ortApi.setProperty( + runtime, "version", + String::createFromUtf8(runtime, OrtGetApiBase()->GetVersionString())); + + runtime.global().setProperty(runtime, "OrtApi", ortApi); + } catch (const std::exception& e) { + throw JSError(runtime, "Failed to install ONNX Runtime JSI bindings: " + + std::string(e.what())); + } + + return env; +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/JsiMain.h b/js/react_native/cpp/JsiMain.h new file mode 100644 index 0000000000000..15c8da084c746 --- /dev/null +++ b/js/react_native/cpp/JsiMain.h @@ -0,0 +1,13 @@ +#pragma once + +#include "Env.h" +#include +#include + +namespace onnxruntimejsi { + +std::shared_ptr +install(facebook::jsi::Runtime& runtime, + std::shared_ptr jsInvoker = nullptr); + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/JsiUtils.cpp b/js/react_native/cpp/JsiUtils.cpp new file mode 100644 index 0000000000000..a5c802f36ef30 --- /dev/null +++ b/js/react_native/cpp/JsiUtils.cpp @@ -0,0 +1,32 @@ 
+#include "JsiUtils.h" + +using namespace facebook::jsi; + +bool isTypedArray(Runtime& runtime, const Object& jsObj) { + if (!jsObj.hasProperty(runtime, "buffer")) + return false; + if (!jsObj.getProperty(runtime, "buffer") + .asObject(runtime) + .isArrayBuffer(runtime)) + return false; + return true; +} + +void forEach(Runtime& runtime, const Object& object, + const std::function& callback) { + auto names = object.getPropertyNames(runtime); + for (size_t i = 0; i < names.size(runtime); i++) { + auto key = + names.getValueAtIndex(runtime, i).asString(runtime).utf8(runtime); + auto value = object.getProperty(runtime, key.c_str()); + callback(key, value, i); + } +} + +void forEach(Runtime& runtime, const Array& array, + const std::function& callback) { + for (size_t i = 0; i < array.size(runtime); i++) { + callback(array.getValueAtIndex(runtime, i), i); + } +} diff --git a/js/react_native/cpp/JsiUtils.h b/js/react_native/cpp/JsiUtils.h new file mode 100644 index 0000000000000..be38d67868df4 --- /dev/null +++ b/js/react_native/cpp/JsiUtils.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +bool isTypedArray(facebook::jsi::Runtime& runtime, + const facebook::jsi::Object& jsObj); + +void forEach( + facebook::jsi::Runtime& runtime, const facebook::jsi::Object& object, + const std::function& callback); + +void forEach( + facebook::jsi::Runtime& runtime, const facebook::jsi::Array& array, + const std::function& callback); diff --git a/js/react_native/cpp/SessionUtils.cpp b/js/react_native/cpp/SessionUtils.cpp new file mode 100644 index 0000000000000..3e6672ec546fb --- /dev/null +++ b/js/react_native/cpp/SessionUtils.cpp @@ -0,0 +1,450 @@ +#include "SessionUtils.h" +#include "JsiUtils.h" +#include "cpu_provider_factory.h" +#include +#include "onnxruntime_cxx_api.h" +#ifdef USE_NNAPI +#include "nnapi_provider_factory.h" +#endif +#ifdef USE_COREML +#include "coreml_provider_factory.h" +#endif + +// Note: Using below syntax for including ort c api and ort 
extensions headers to resolve a compiling error happened +// in an expo react native ios app when ort extensions enabled (a redefinition error of multiple object types defined +// within ORT C API header). It's an edge case that compiler allows both ort c api headers to be included when #include +// syntax doesn't match. For the case when extensions not enabled, it still requires a onnxruntime prefix directory for +// searching paths. Also in general, it's a convention to use #include for C/C++ headers rather then #import. See: +// https://google.github.io/styleguide/objcguide.html#import-and-include +// https://microsoft.github.io/objc-guide/Headers/ImportAndInclude.html +#if defined(ORT_ENABLE_EXTENSIONS) && defined(__APPLE__) +#include +#endif + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +const std::vector supportedBackends = { + "cpu", + "xnnpack", +#ifdef USE_COREML + "coreml", +#endif +#ifdef USE_NNAPI + "nnapi", +#endif +#ifdef USE_QNN + "qnn", +#endif +}; + +class ExtendedSessionOptions : public Ort::SessionOptions { + public: + ExtendedSessionOptions() = default; + + void AppendExecutionProvider_CPU(int use_arena) { + Ort::ThrowOnError( + OrtSessionOptionsAppendExecutionProvider_CPU(this->p_, use_arena)); + } + + void AddFreeDimensionOverrideByName(const char* name, int64_t value) { + Ort::ThrowOnError( + Ort::GetApi().AddFreeDimensionOverrideByName(this->p_, name, value)); + } +#ifdef USE_NNAPI + void AppendExecutionProvider_Nnapi(uint32_t nnapi_flags) { + Ort::ThrowOnError( + OrtSessionOptionsAppendExecutionProvider_Nnapi(this->p_, nnapi_flags)); + } +#endif +#ifdef USE_COREML + void AppendExecutionProvider_CoreML(int flags) { + Ort::ThrowOnError( + OrtSessionOptionsAppendExecutionProvider_CoreML(this->p_, flags)); + } +#endif +}; + +void parseSessionOptions(Runtime& runtime, const Value& optionsValue, + Ort::SessionOptions& sessionOptions) { + if (!optionsValue.isObject()) + return; + + auto options = optionsValue.asObject(runtime); 
+ + try { +#ifdef ORT_ENABLE_EXTENSIONS + // ortExtLibPath + if (options.hasProperty(runtime, "ortExtLibPath")) { +#ifdef __APPLE__ + Ort::ThrowOnError(RegisterCustomOps(sessionOptions, OrtGetApiBase())); +#endif +#ifdef __ANDROID__ + auto prop = options.getProperty(runtime, "ortExtLibPath"); + if (prop.isString()) { + std::string libraryPath = prop.asString(runtime).utf8(runtime); + sessionOptions.RegisterCustomOpsLibrary(libraryPath.c_str()); + } +#endif + } +#endif + + // intraOpNumThreads + if (options.hasProperty(runtime, "intraOpNumThreads")) { + auto prop = options.getProperty(runtime, "intraOpNumThreads"); + if (prop.isNumber()) { + int numThreads = static_cast(prop.asNumber()); + if (numThreads > 0) { + sessionOptions.SetIntraOpNumThreads(numThreads); + } + } + } + + // interOpNumThreads + if (options.hasProperty(runtime, "interOpNumThreads")) { + auto prop = options.getProperty(runtime, "interOpNumThreads"); + if (prop.isNumber()) { + int numThreads = static_cast(prop.asNumber()); + if (numThreads > 0) { + sessionOptions.SetInterOpNumThreads(numThreads); + } + } + } + + // freeDimensionOverrides + if (options.hasProperty(runtime, "freeDimensionOverrides")) { + auto prop = options.getProperty(runtime, "freeDimensionOverrides"); + if (prop.isObject()) { + auto overrides = prop.asObject(runtime); + forEach(runtime, overrides, + [&](const std::string& key, const Value& value, size_t index) { + reinterpret_cast(sessionOptions) + .AddFreeDimensionOverrideByName( + key.c_str(), static_cast(value.asNumber())); + }); + } + } + + // graphOptimizationLevel + if (options.hasProperty(runtime, "graphOptimizationLevel")) { + auto prop = options.getProperty(runtime, "graphOptimizationLevel"); + if (prop.isString()) { + std::string level = prop.asString(runtime).utf8(runtime); + if (level == "disabled") { + sessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + } else if (level == "basic") { + sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); + } else if 
(level == "extended") { + sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED); + } else if (level == "all") { + sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + } + } + } + + // enableCpuMemArena + if (options.hasProperty(runtime, "enableCpuMemArena")) { + auto prop = options.getProperty(runtime, "enableCpuMemArena"); + if (prop.isBool()) { + if (prop.asBool()) { + sessionOptions.EnableCpuMemArena(); + } else { + sessionOptions.DisableCpuMemArena(); + } + } + } + + // enableMemPattern + if (options.hasProperty(runtime, "enableMemPattern")) { + auto prop = options.getProperty(runtime, "enableMemPattern"); + if (prop.isBool()) { + if (prop.asBool()) { + sessionOptions.EnableMemPattern(); + } else { + sessionOptions.DisableMemPattern(); + } + } + } + + // executionMode + if (options.hasProperty(runtime, "executionMode")) { + auto prop = options.getProperty(runtime, "executionMode"); + if (prop.isString()) { + std::string mode = prop.asString(runtime).utf8(runtime); + if (mode == "sequential") { + sessionOptions.SetExecutionMode(ORT_SEQUENTIAL); + } else if (mode == "parallel") { + sessionOptions.SetExecutionMode(ORT_PARALLEL); + } + } + } + + // optimizedModelFilePath + if (options.hasProperty(runtime, "optimizedModelFilePath")) { + auto prop = options.getProperty(runtime, "optimizedModelFilePath"); + if (prop.isString()) { + std::string path = prop.asString(runtime).utf8(runtime); + sessionOptions.SetOptimizedModelFilePath(path.c_str()); + } + } + + // enableProfiling + if (options.hasProperty(runtime, "enableProfiling")) { + auto prop = options.getProperty(runtime, "enableProfiling"); + if (prop.isBool() && prop.asBool()) { + sessionOptions.EnableProfiling("onnxruntime_profile_"); + } + } + + // profileFilePrefix + if (options.hasProperty(runtime, "profileFilePrefix")) { + auto enableProfilingProp = + options.getProperty(runtime, "enableProfiling"); + if (enableProfilingProp.isBool() && enableProfilingProp.asBool()) { + auto prop = 
options.getProperty(runtime, "profileFilePrefix"); + if (prop.isString()) { + std::string prefix = prop.asString(runtime).utf8(runtime); + sessionOptions.EnableProfiling(prefix.c_str()); + } + } + } + + // logId + if (options.hasProperty(runtime, "logId")) { + auto prop = options.getProperty(runtime, "logId"); + if (prop.isString()) { + std::string logId = prop.asString(runtime).utf8(runtime); + sessionOptions.SetLogId(logId.c_str()); + } + } + + // logSeverityLevel + if (options.hasProperty(runtime, "logSeverityLevel")) { + auto prop = options.getProperty(runtime, "logSeverityLevel"); + if (prop.isNumber()) { + int level = static_cast(prop.asNumber()); + if (level >= 0 && level <= 4) { + sessionOptions.SetLogSeverityLevel(level); + } + } + } + + // externalData + if (options.hasProperty(runtime, "externalData")) { + auto prop = + options.getProperty(runtime, "externalData").asObject(runtime); + if (prop.isArray(runtime)) { + auto externalDataArray = prop.asArray(runtime); + std::vector paths; + std::vector buffs; + std::vector sizes; + forEach( + runtime, externalDataArray, [&](const Value& value, size_t index) { + if (value.isObject()) { + auto externalDataObject = value.asObject(runtime); + if (externalDataObject.hasProperty(runtime, "path")) { + auto pathValue = + externalDataObject.getProperty(runtime, "path"); + if (pathValue.isString()) { + paths.push_back(pathValue.asString(runtime).utf8(runtime)); + } + } + if (externalDataObject.hasProperty(runtime, "data")) { + auto dataValue = + externalDataObject.getProperty(runtime, "data") + .asObject(runtime); + if (isTypedArray(runtime, dataValue)) { + auto arrayBuffer = dataValue.getProperty(runtime, "buffer") + .asObject(runtime) + .getArrayBuffer(runtime); + buffs.push_back( + reinterpret_cast(arrayBuffer.data(runtime))); + sizes.push_back(arrayBuffer.size(runtime)); + } + } + } + }); + sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, + sizes); + } + } + + // executionProviders + if 
(options.hasProperty(runtime, "executionProviders")) { + auto prop = options.getProperty(runtime, "executionProviders"); + if (prop.isObject() && prop.asObject(runtime).isArray(runtime)) { + auto providers = prop.asObject(runtime).asArray(runtime); + forEach(runtime, providers, [&](const Value& epValue, size_t index) { + std::string epName; + std::unique_ptr providerObj; + if (epValue.isString()) { + epName = epValue.asString(runtime).utf8(runtime); + } else if (epValue.isObject()) { + providerObj = std::make_unique(epValue.asObject(runtime)); + epName = providerObj->getProperty(runtime, "name") + .asString(runtime) + .utf8(runtime); + } + + // Apply execution providers + if (epName == "cpu") { + int use_arena = 0; + if (providerObj && providerObj->hasProperty(runtime, "useArena")) { + auto useArena = providerObj->getProperty(runtime, "useArena"); + if (useArena.isBool() && useArena.asBool()) { + use_arena = 1; + } + } + reinterpret_cast(sessionOptions) + .AppendExecutionProvider_CPU(use_arena); + } else if (epName == "xnnpack") { + sessionOptions.AppendExecutionProvider("XNNPACK"); + } +#ifdef USE_COREML + else if (epName == "coreml") { + int flags = 0; + if (providerObj && + providerObj->hasProperty(runtime, "coreMlFlags")) { + auto flagsValue = + providerObj->getProperty(runtime, "coreMlFlags"); + if (flagsValue.isNumber()) { + flags = static_cast(flagsValue.asNumber()); + } + } + reinterpret_cast(sessionOptions) + .AppendExecutionProvider_CoreML(flags); + } +#endif +#ifdef USE_NNAPI + else if (epName == "nnapi") { + uint32_t nnapi_flags = 0; + if (providerObj && providerObj->hasProperty(runtime, "useFP16")) { + auto useFP16 = providerObj->getProperty(runtime, "useFP16"); + if (useFP16.isBool() && useFP16.asBool()) { + nnapi_flags |= NNAPI_FLAG_USE_FP16; + } + } + if (providerObj && providerObj->hasProperty(runtime, "useNCHW")) { + auto useNCHW = providerObj->getProperty(runtime, "useNCHW"); + if (useNCHW.isBool() && useNCHW.asBool()) { + nnapi_flags |= 
NNAPI_FLAG_USE_NCHW; + } + } + if (providerObj && + providerObj->hasProperty(runtime, "cpuDisabled")) { + auto cpuDisabled = + providerObj->getProperty(runtime, "cpuDisabled"); + if (cpuDisabled.isBool() && cpuDisabled.asBool()) { + nnapi_flags |= NNAPI_FLAG_CPU_DISABLED; + } + } + if (providerObj && providerObj->hasProperty(runtime, "cpuOnly")) { + auto cpuOnly = providerObj->getProperty(runtime, "cpuOnly"); + if (cpuOnly.isBool() && cpuOnly.asBool()) { + nnapi_flags |= NNAPI_FLAG_CPU_ONLY; + } + } + reinterpret_cast(sessionOptions) + .AppendExecutionProvider_Nnapi(nnapi_flags); + } +#endif +#ifdef USE_QNN + else if (epName == "qnn") { + std::unordered_map options; + if (providerObj && + providerObj->hasProperty(runtime, "backendType")) { + options["backendType"] = + providerObj->getProperty(runtime, "backendType") + .asString(runtime) + .utf8(runtime); + } + if (providerObj && + providerObj->hasProperty(runtime, "backendPath")) { + options["backendPath"] = + providerObj->getProperty(runtime, "backendPath") + .asString(runtime) + .utf8(runtime); + } + if (providerObj && + providerObj->hasProperty(runtime, "enableFp16Precision")) { + auto enableFp16Precision = + providerObj->getProperty(runtime, "enableFp16Precision"); + if (enableFp16Precision.isBool() && + enableFp16Precision.asBool()) { + options["enableFp16Precision"] = "1"; + } else { + options["enableFp16Precision"] = "0"; + } + } + sessionOptions.AppendExecutionProvider("QNN", options); + } +#endif + else { + throw JSError(runtime, "Unsupported execution provider: " + epName); + } + }); + } + } + } catch (const JSError& e) { + throw e; + } catch (const std::exception& e) { + throw JSError(runtime, + "Failed to parse session options: " + std::string(e.what())); + } +} + +void parseRunOptions(Runtime& runtime, const Value& optionsValue, + Ort::RunOptions& runOptions) { + if (!optionsValue.isObject()) + return; + + auto options = optionsValue.asObject(runtime); + + try { + // tag + if 
(options.hasProperty(runtime, "tag")) { + auto prop = options.getProperty(runtime, "tag"); + if (prop.isString()) { + std::string tag = prop.asString(runtime).utf8(runtime); + runOptions.SetRunTag(tag.c_str()); + } + } + + // logSeverityLevel + if (options.hasProperty(runtime, "logSeverityLevel")) { + auto prop = options.getProperty(runtime, "logSeverityLevel"); + if (prop.isNumber()) { + int level = static_cast(prop.asNumber()); + if (level >= 0 && level <= 4) { + runOptions.SetRunLogSeverityLevel(level); + } + } + } + + // logVerbosityLevel + if (options.hasProperty(runtime, "logVerbosityLevel")) { + auto prop = options.getProperty(runtime, "logVerbosityLevel"); + if (prop.isNumber()) { + int level = static_cast(prop.asNumber()); + if (level >= 0) { + runOptions.SetRunLogVerbosityLevel(level); + } + } + } + + // terminate + if (options.hasProperty(runtime, "terminate")) { + auto prop = options.getProperty(runtime, "terminate"); + if (prop.isBool() && prop.asBool()) { + runOptions.SetTerminate(); + } + } + + } catch (const std::exception& e) { + throw JSError(runtime, + "Failed to parse run options: " + std::string(e.what())); + } +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/SessionUtils.h b/js/react_native/cpp/SessionUtils.h new file mode 100644 index 0000000000000..4dafcd01ab845 --- /dev/null +++ b/js/react_native/cpp/SessionUtils.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include "onnxruntime_cxx_api.h" + +namespace onnxruntimejsi { + +extern const std::vector supportedBackends; + +void parseSessionOptions(facebook::jsi::Runtime& runtime, + const facebook::jsi::Value& optionsValue, + Ort::SessionOptions& sessionOptions); + +void parseRunOptions(facebook::jsi::Runtime& runtime, + const facebook::jsi::Value& optionsValue, + Ort::RunOptions& runOptions); + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/TensorUtils.cpp b/js/react_native/cpp/TensorUtils.cpp new file mode 100644 index 0000000000000..79d270d883294 --- 
/dev/null +++ b/js/react_native/cpp/TensorUtils.cpp @@ -0,0 +1,236 @@ +#include "TensorUtils.h" +#include "JsiUtils.h" +#include +#include +#include + +using namespace facebook::jsi; + +namespace onnxruntimejsi { + +static const std::unordered_map + dataTypeToStringMap = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, "float32"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, "uint8"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, "int8"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, "uint16"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, "int16"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, "int32"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, "int64"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, "string"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, "bool"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, "float16"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, "float64"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, "uint32"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, "uint64"}, +}; + +static const std::unordered_map + elementSizeMap = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, sizeof(int16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, sizeof(int64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, sizeof(char*)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, sizeof(bool)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, 2}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, sizeof(double)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, sizeof(uint32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, sizeof(uint64_t)}, +}; + +static const std::unordered_map + dataTypeToTypedArrayMap = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, "Float32Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, "Float64Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, "Int32Array"}, + 
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, "BigInt64Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, "Uint32Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, "BigUint64Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, "Uint8Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, "Int8Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, "Uint16Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, "Int16Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, "Float16Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, "Array"}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, "Uint8Array"}, +}; + +inline size_t getElementSize(ONNXTensorElementDataType dataType) { + auto it = elementSizeMap.find(dataType); + if (it != elementSizeMap.end()) { + return it->second; + } + throw std::invalid_argument("Unsupported or unknown tensor data type: " + + std::to_string(static_cast(dataType))); +} + +bool TensorUtils::isTensor(Runtime& runtime, const Object& obj) { + return obj.hasProperty(runtime, "cpuData") && + obj.hasProperty(runtime, "dims") && obj.hasProperty(runtime, "type"); +} + +inline Object getTypedArrayConstructor(Runtime& runtime, + const ONNXTensorElementDataType type) { + auto it = dataTypeToTypedArrayMap.find(type); + if (it != dataTypeToTypedArrayMap.end()) { + auto prop = runtime.global().getProperty(runtime, it->second); + if (prop.isObject()) { + return prop.asObject(runtime); + } else { + throw JSError(runtime, "TypedArray constructor not found: " + + std::string(it->second)); + } + } + throw JSError(runtime, + "Unsupported tensor data type for TypedArray creation: " + + std::to_string(static_cast(type))); +} + +size_t getElementCount(const std::vector& shape) { + size_t count = 1; + for (auto dim : shape) { + count *= dim; + } + return count; +} + +Ort::Value +TensorUtils::createOrtValueFromJSTensor(Runtime& runtime, + const Object& tensorObj, + const Ort::MemoryInfo& memoryInfo) { + if (!isTensor(runtime, tensorObj)) { + throw JSError( + runtime, + "Invalid tensor object: missing cpuData, dims, 
or type properties"); + } + + auto dataProperty = tensorObj.getProperty(runtime, "cpuData"); + auto dimsProperty = tensorObj.getProperty(runtime, "dims"); + auto typeProperty = tensorObj.getProperty(runtime, "type"); + + if (!dimsProperty.isObject() || + !dimsProperty.asObject(runtime).isArray(runtime)) { + throw JSError(runtime, "Tensor dims must be array"); + } + + if (!typeProperty.isString()) { + throw JSError(runtime, "Tensor type must be string"); + } + + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + auto typeStr = typeProperty.asString(runtime).utf8(runtime); + for (auto it = dataTypeToStringMap.begin(); it != dataTypeToStringMap.end(); + ++it) { + if (it->second == typeStr) { + type = it->first; + break; + } + } + if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + throw JSError(runtime, "Unsupported tensor data type: " + typeStr); + } + + void* data = nullptr; + auto dataObj = dataProperty.asObject(runtime); + + if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + if (!dataObj.isArray(runtime)) { + throw JSError(runtime, "Tensor data must be an array of strings"); + } + auto array = dataObj.asArray(runtime); + auto size = array.size(runtime); + data = new char*[size]; + for (size_t i = 0; i < size; ++i) { + auto item = array.getValueAtIndex(runtime, i); + static_cast(data)[i] = + strdup(item.toString(runtime).utf8(runtime).c_str()); + } + } else { + if (!isTypedArray(runtime, dataObj)) { + throw JSError(runtime, "Tensor data must be a TypedArray"); + } + auto buffer = dataObj.getProperty(runtime, "buffer") + .asObject(runtime) + .getArrayBuffer(runtime); + data = buffer.data(runtime); + } + + std::vector shape; + auto dimsArray = dimsProperty.asObject(runtime).asArray(runtime); + for (size_t i = 0; i < dimsArray.size(runtime); ++i) { + auto dim = dimsArray.getValueAtIndex(runtime, i); + if (dim.isNumber()) { + shape.push_back(static_cast(dim.asNumber())); + } + } + + return Ort::Value::CreateTensor(memoryInfo, data, + 
getElementCount(shape) * getElementSize(type), + shape.data(), shape.size(), type); +} + +Object +TensorUtils::createJSTensorFromOrtValue(Runtime& runtime, Ort::Value& ortValue, + const Object& tensorConstructor) { + auto typeInfo = ortValue.GetTensorTypeAndShapeInfo(); + auto shape = typeInfo.GetShape(); + auto elementType = typeInfo.GetElementType(); + + std::string typeStr; + auto it = dataTypeToStringMap.find(elementType); + if (it != dataTypeToStringMap.end()) { + typeStr = it->second; + } else { + throw JSError(runtime, + "Unsupported tensor data type for TypedArray creation: " + + std::to_string(static_cast(elementType))); + } + + auto dimsArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); ++j) { + dimsArray.setValueAtIndex(runtime, j, Value(static_cast(shape[j]))); + } + + if (elementType != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + void* rawData = ortValue.GetTensorMutableRawData(); + size_t elementCount = + ortValue.GetTensorTypeAndShapeInfo().GetElementCount(); + size_t elementSize = getElementSize(elementType); + size_t dataSize = elementCount * elementSize; + + auto typedArrayCtor = getTypedArrayConstructor(runtime, elementType); + auto typedArrayInstance = + typedArrayCtor.asFunction(runtime).callAsConstructor( + runtime, static_cast(elementCount)); + + auto buffer = typedArrayInstance.asObject(runtime) + .getProperty(runtime, "buffer") + .asObject(runtime) + .getArrayBuffer(runtime); + memcpy(buffer.data(runtime), rawData, dataSize); + + auto tensorInstance = + tensorConstructor.asFunction(runtime).callAsConstructor( + runtime, typeStr, typedArrayInstance, dimsArray); + + return tensorInstance.asObject(runtime); + } else { + auto strArray = Array(runtime, shape.size()); + for (size_t j = 0; j < shape.size(); ++j) { + strArray.setValueAtIndex( + runtime, j, Value(runtime, String::createFromUtf8(runtime, ""))); + } + + auto tensorInstance = + tensorConstructor.asFunction(runtime).callAsConstructor( + runtime, typeStr, 
strArray, dimsArray); + + return tensorInstance.asObject(runtime); + } +} + +} // namespace onnxruntimejsi diff --git a/js/react_native/cpp/TensorUtils.h b/js/react_native/cpp/TensorUtils.h new file mode 100644 index 0000000000000..5361f5cb1101f --- /dev/null +++ b/js/react_native/cpp/TensorUtils.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include "onnxruntime_cxx_api.h" +#include +#include + +namespace onnxruntimejsi { + +class TensorUtils { + public: + static Ort::Value + createOrtValueFromJSTensor(facebook::jsi::Runtime& runtime, + const facebook::jsi::Object& tensorObj, + const Ort::MemoryInfo& memoryInfo); + + static facebook::jsi::Object + createJSTensorFromOrtValue(facebook::jsi::Runtime& runtime, + Ort::Value& ortValue, + const facebook::jsi::Object& tensorConstructor); + + static bool isTensor(facebook::jsi::Runtime& runtime, + const facebook::jsi::Object& obj); +}; + +} // namespace onnxruntimejsi diff --git a/js/react_native/e2e/android/app/build.gradle b/js/react_native/e2e/android/app/build.gradle index fa94c00a32bd0..54d5e55a209d8 100644 --- a/js/react_native/e2e/android/app/build.gradle +++ b/js/react_native/e2e/android/app/build.gradle @@ -116,17 +116,10 @@ android { } } -repositories { - flatDir { - dir 'libs' - } -} - dependencies { androidTestImplementation('com.wix:detox:+') implementation 'androidx.appcompat:appcompat:1.1.0' - implementation fileTree(dir: "libs", include: ["*.jar"]) // The version of react-native is set by the React Native Gradle Plugin implementation("com.facebook.react:react-android") implementation("com.facebook.react:flipper-integration") @@ -143,8 +136,6 @@ dependencies { androidTestImplementation "androidx.test:rules:1.5.0" implementation (project(':onnxruntime-react-native')) - // specify ORT dependency here so it can be found in libs flatDir repository - implementation "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar" } // Run this once to be able to run the application with BUCK diff --git 
a/js/react_native/android/src/androidTest/res/raw/test_types_bool.onnx b/js/react_native/e2e/android/app/src/main/assets/test_types_bool.onnx similarity index 100% rename from js/react_native/android/src/androidTest/res/raw/test_types_bool.onnx rename to js/react_native/e2e/android/app/src/main/assets/test_types_bool.onnx diff --git a/js/react_native/android/src/androidTest/res/raw/test_types_double.onnx b/js/react_native/e2e/android/app/src/main/assets/test_types_double.onnx similarity index 100% rename from js/react_native/android/src/androidTest/res/raw/test_types_double.onnx rename to js/react_native/e2e/android/app/src/main/assets/test_types_double.onnx diff --git a/js/react_native/e2e/android/app/src/main/assets/test_types_float.ort b/js/react_native/e2e/android/app/src/main/assets/test_types_float.ort new file mode 100644 index 0000000000000000000000000000000000000000..2e8377a4afc1fa336bcd8b6d486e14a44a2cf091 GIT binary patch literal 1824 zcmZ`(&ubGw6n@dQjj`6CL4t&mLy#QG(xNS;cnrwpmDa!@6zbMLkFn z@z8^ae}Q=P;Ga-LJcxMmAU%7M9s`E1-#0r`Hw%66?R)d)&G+8Cnc0$v+`iMiIUz-v zkUT~fv?%<79*l|n_#4L-?3|hvnI6V*N!y2@I{?Q=fyKKW&v6%L&%tgYt^&=q#08)Z z&{qWLYkK}Nl>Cm99J|nmF&;pZJFjbEWz7E}le-EZxj5I9F+`8!t0ZS319Oy+wTi4| zWhZ-Ny$b-+P@ch*!4Z)m@EVI)1EnN2y|$z~WU0M25Ku;39UvF;QxAa}<~U|P>Z^cA zLG{!^g$0b*V=PU{l+4I9DrLT;M`bW3r(($zd@~v~=!*b(n*jTN99Tf@w0nRNxKs_S z!+>+NAE8dxum$}B=mQ;q`?%7$cdhR_YmpcDjWAgEmfa9;A)h*84l=^IHz2M6pI|=# z$iKH2$Lx9)+)uTQVJ?C_u4Cq5GnRExU)ST0I2^ip7mVD3k9=R?{FdU|PVsGK_}*g7 zp69Wb~(zdulSdhsVImTKIM#xjb$p?VU2f z)_)-fYcj&|dEj$qeGNXYH#tp>>Gibd&?&w&zREmyIjPqjREOb9^I(a}Gp{v!#s5&S z7WLMmq<7p;`rpU(CVmb!$J__&^nA5U^`?P|60i&ZHo(3;1ZDu{`H4HowG@uIlOSBR zdrzVublP?+a(&M6XP1sU1(y?a?XKs9!Q)^hvVEt$;Yd9b$NyN%vjgAn+r7|T_uLIT z3R|A<`KzSmn<^r+xjXbxSPTD4BRMdBnc-)>hCj{Aw{OmjT)eL)uG0X=%w{30wPT`43y-|3Lr% literal 0 HcmV?d00001 diff --git a/js/react_native/e2e/android/app/src/main/assets/test_types_int32.ort b/js/react_native/e2e/android/app/src/main/assets/test_types_int32.ort new file mode 100644 index 
0000000000000000000000000000000000000000..15d3cc1f8903f79410c95561029604191c5192dd GIT binary patch literal 1824 zcmZ`(O=}ZT6ur^0iLut8g9HgBiy&E)A)+m%xbPbbeguu;!i8ZvZ3dE=FizUIQ5RB* zxadN}U*Ot>KcR@Y5OL)~x^^X91}q)VnR%~H272MlyYIex&pq$GnK4D=?!DISNh!&s z6fiR|N{n9^M-w7H{zj1jKNqix%)Ou3x%2Q_nY=1uC4f9F znLxKZp0VW&^kt6?)KoCiCoTeYfHBrG)^fci4DvflbnRd~jQI#7^%T@TxEuB|!B7PQ zb((68`sV?Z&|bim(Gihd;0>CPk7i0u)tc&r48%m5NGKz(4sf3IK|TB$);MQBj@JN@ zVtl3s`WG?d8KX31)6$R`C}qE7zhy8dw_?aNVhx2F#!CQsTR;~$4$MP4{XU=#p;!ZT zXmE}G6X>LdO^mOAANl&4uvuDrXKpbE%8t4K@FW%VC)Js_hP9%l2s181Dzg+X3%cdp!C7} zZ{(mRZJb{KK4ebUd!{XP3@9g^=y{te$HJ{!@QBv^=}BW=lPj$NyN{wS3PTTK<b)7XU4BD>e zdMl*mH&sGs-lOJcoLcx_8tZ}iMuwkyHGi6y-@d*xa`C^Z^Lky8?-`ymSRdnrwpmDa!@6zbMLkFn z@z8^ae}Q=P;Ga-LJcxMmAU%7M9s`E1-#0r`Hw%66?R)d)&G+8Cnc0$v+`iMiIUz-v zkUT~fv?%<79*l|n_#4L-?3|hvnI6V*N!y2@I{?Q=fyKKW&v6%L&%tgYt^&=q#08)Z z&{qWLYkK}Nl>Cm99J|nmF&;pZJFjbEWz7E}le-EZxj5I9F+`8!t0ZS319Oy+wTi4| zWhZ-Ny$b-+P@ch*!4Z)m@EVI)1EnN2y|$z~WU0M25Ku;39UvF;QxAa}<~U|P>Z^cA zLG{!^g$0b*V=PU{l+4I9DrLT;M`bW3r(($zd@~v~=!*b(n*jTN99Tf@w0nRNxKs_S z!+>+NAE8dxum$}B=mQ;q`?%7$cdhR_YmpcDjWAgEmfa9;A)h*84l=^IHz2M6pI|=# z$iKH2$Lx9)+)uTQVJ?C_u4Cq5GnRExU)ST0I2^ip7mVD3k9=R?{FdU|PVsGK_}*g7 zp69W@~&L{w?ii{}L(4q{?Ke`^GcvsmbIX9s?_D;oE`a^0<+_o2uJxSQ+#QBm3i!PQm;Fx4#St`!4j2cUTgM>|Dj+l z>a9gd@3^1zzmMxp{2XkKxewOq`D&T!O#>4pU>E*vfPH%i%mB>u6L*knDI9YrLAYx7 zoE**CYE+^>PUC#-F$H7Wu`%Zhqk$NVM|FM>52fp98d!f7Txf^yA zwmjeSS4qn^RYYcnrwpmDa!@6zbMLkFn z@z8^ae}Q=P;Ga-LJcxMmAU%7M9s`E1-#0r`Hw%66?R)d)&G+8Cnc0$v+`iMiIUz-v zkUT~fv?%<79*l|n_#4L-?3|hvnI6V*N!y2@I{?Q=fyKKW&v6%L&%tgYt^&=q#08)Z z&{qWLYkK}Nl>Cm99J|nmF&;pZJFjbEWz7E}le-EZxj5I9F+`8!t0ZS319Oy+wTi4| zWhZ-Ny$b-+P@ch*!4Z)m@EVI)1EnN2y|$z~WU0M25Ku;39UvF;QxAa}<~U|P>Z^cA zLG{!^g$0b*V=PU{l+4I9DrLT;M`bW3r(($zd@~v~=!*b(n*jTN99Tf@w0nRNxKs_S z!+>+NAE8dxum$}B=mQ;q`?%7$cdhR_YmpcDjWAgEmfa9;A)h*84l=^IHz2M6pI|=# z$iKH2$Lx9)+)uTQVJ?C_u4Cq5GnRExU)ST0I2^ip7mVD3k9=R?{FdU|PVsGK_}*g7 zp69WSb$Y&9rh3!BL1QqcFOB5D_+^Hl^&0*(FWta4CY62 
m*>1n*hF;h8qZZ=xs4}_bA}0__A+Hgh_obyJewIaABBE+ks>zjFUEQ)Pi!9rcuSVv9;nl^C}r~`~Km$8-~EI`TcB<*V(`Y`+*X!cXo>)>iw%LYRg4D8d^ zW9)wxKnmprOxZaivJ1RH5%N)Nsp(u>?~sLn&khcYNfaAa{dZ*n3j3A_XU>^n? zqx}SZvWE@mSHJ-109?=2=7Z}4&+W&47&N1B#b59ugoS+Si0R~nV{buR0Y1Tg1dxAk zFG+Y7Rd7GoHrzrCdsNSvfz7$>i~5H4KjLs0=D8ZVgBba~AowlAx0&Hv%kjN~&pj_< zb))z)A2P)ILGvzy_pG)#rvlsbxs~Jnv(J0_K;HE1SGSTfk1fQ~eLOj|Uek}LcWtS6oz5!_%`^SsD#&}_7NE~U*1*kvxYvB|bwKwA>%WkL zJsF{Y0r;HDuOY^KQ`7XA&8Iz$o#IR5t1M!b6FTppJ`7)$2Pvw~jMm&M{@cQS-0R2b z-f=zI{}}Vl{At+qxeoT}2P#eNO%nq)u#Na8ztZ^dfJ?_g0-aYWsm7 zERmLPs)WisN5juJd*Od+s0RG?96$Rt{8?VUeRF2y;(yc3>vlxG=XlOwe5jV=4SHVW hcfBBPV}21`rng+=1m;r6YlP>0d44`vjN1EW`45Kj{{sL3 literal 0 HcmV?d00001 diff --git a/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java b/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java index b5f58f39ea8ca..72f27b3291a8c 100644 --- a/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java +++ b/js/react_native/e2e/android/app/src/main/java/com/reactnativeonnxruntimemodule/MNISTDataHandler.java @@ -5,8 +5,6 @@ import static java.util.stream.Collectors.joining; -import ai.onnxruntime.reactnative.OnnxruntimeModule; -import ai.onnxruntime.reactnative.TensorHelper; import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; @@ -149,7 +147,7 @@ private WritableMap preprocess(String uri) throws Exception { inputTensorMap.putArray("dims", dims); // type - inputTensorMap.putString("type", TensorHelper.JsTensorTypeFloat); + inputTensorMap.putString("type", "float32"); // data encoded as Base64 imageByteBuffer.rewind(); diff --git a/js/react_native/e2e/android/build.gradle b/js/react_native/e2e/android/build.gradle index 9ad8256fc52dc..8d1f9d59d8649 100644 --- a/js/react_native/e2e/android/build.gradle +++ b/js/react_native/e2e/android/build.gradle @@ -39,6 +39,10 @@ allprojects { // Add 
Detox as a precompiled native dependency url("$rootDir/../node_modules/detox/Detox-android") } + maven { + // Local onnxruntime-android package + url("$rootDir/app/libs") + } google() mavenCentral() @@ -46,4 +50,4 @@ allprojects { } } -apply plugin: "com.facebook.react.rootproject" \ No newline at end of file +apply plugin: "com.facebook.react.rootproject" diff --git a/js/react_native/e2e/android/gradle.properties b/js/react_native/e2e/android/gradle.properties index ede6147623f19..a8840c7b7d214 100644 --- a/js/react_native/e2e/android/gradle.properties +++ b/js/react_native/e2e/android/gradle.properties @@ -22,4 +22,6 @@ android.enableJetifier=true org.gradle.jvmargs=-Xmx8192m -XX:MaxMetaspaceSize=2048m -Dkotlin.daemon.jvm.options=-Xmx8192m # Use this property to enable or disable the Hermes JS engine. # If set to false, you will be using JSC instead. -hermesEnabled=false +hermesEnabled=true + +reactNativeArchitectures=x86_64 diff --git a/js/react_native/e2e/ios/MNISTDataHandler.h b/js/react_native/e2e/ios/MNISTDataHandler.h index da05843e8a41f..595cae82ea91c 100644 --- a/js/react_native/e2e/ios/MNISTDataHandler.h +++ b/js/react_native/e2e/ios/MNISTDataHandler.h @@ -4,7 +4,7 @@ #ifndef MNISTDataHandler_h #define MNISTDataHandler_h -#import +#import @interface MNISTDataHandler : NSObject @end diff --git a/js/react_native/e2e/ios/MNISTDataHandler.mm b/js/react_native/e2e/ios/MNISTDataHandler.mm index 1a79b66ca5d2f..6c27607eff1ed 100644 --- a/js/react_native/e2e/ios/MNISTDataHandler.mm +++ b/js/react_native/e2e/ios/MNISTDataHandler.mm @@ -2,10 +2,9 @@ // Licensed under the MIT License. 
#import "MNISTDataHandler.h" -#import "OnnxruntimeModule.h" -#import "TensorHelper.h" #import #import +#include NS_ASSUME_NONNULL_BEGIN @@ -119,7 +118,7 @@ - (NSDictionary*)preprocess:(NSString*)uri { inputTensorMap[@"dims"] = dims; // type - inputTensorMap[@"type"] = JsTensorTypeFloat; + inputTensorMap[@"type"] = @"float32"; // encoded data NSString* data = [byteBufferRef base64EncodedStringWithOptions:0]; diff --git a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj index 6f957af603385..70a5fcdd33cad 100644 --- a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj +++ b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj @@ -10,6 +10,13 @@ 13B07FBC1A68108700A75B9A /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 13B07FB01A68108700A75B9A /* AppDelegate.m */; }; 13B07FBF1A68108700A75B9A /* Images.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 13B07FB51A68108700A75B9A /* Images.xcassets */; }; 13B07FC11A68108700A75B9A /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 13B07FB71A68108700A75B9A /* main.m */; }; + 3ADD0A3C2EBB64D200761D6F /* ../src/test_types_int8.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A382EBB64D200761D6F /* ../src/test_types_int8.ort */; }; + 3ADD0A3D2EBB64D200761D6F /* ../src/test_types_int64.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A3A2EBB64D200761D6F /* ../src/test_types_int64.ort */; }; + 3ADD0A3E2EBB64D200761D6F /* ../src/test_types_int32.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A392EBB64D200761D6F /* ../src/test_types_int32.ort */; }; + 3ADD0A3F2EBB64D200761D6F /* ../src/test_types_float.ort in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A372EBB64D200761D6F /* ../src/test_types_float.ort */; }; + 3ADD0A402EBB64D200761D6F /* ../src/test_types_uint8.ort in Resources */ = {isa = PBXBuildFile; fileRef = 
3ADD0A3B2EBB64D200761D6F /* ../src/test_types_uint8.ort */; }; + 3ADD0A422EBB677300761D6F /* test_types_double.onnx in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A412EBB677300761D6F /* test_types_double.onnx */; }; + 3ADD0A442EBB679A00761D6F /* test_types_bool.onnx in Resources */ = {isa = PBXBuildFile; fileRef = 3ADD0A432EBB679A00761D6F /* test_types_bool.onnx */; }; 81AB9BB82411601600AC10FF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 81AB9BB72411601600AC10FF /* LaunchScreen.storyboard */; }; DB61BA27278684FB0096C971 /* OnnxruntimeModuleExampleUITests.m in Sources */ = {isa = PBXBuildFile; fileRef = DB61BA26278684FB0096C971 /* OnnxruntimeModuleExampleUITests.m */; }; DBA8BA87267293C4008CC55A /* mnist.ort in Resources */ = {isa = PBXBuildFile; fileRef = DBA8BA86267293C4008CC55A /* mnist.ort */; }; @@ -50,6 +57,13 @@ 13B07FB51A68108700A75B9A /* Images.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; name = Images.xcassets; path = OnnxruntimeModuleExample/Images.xcassets; sourceTree = ""; }; 13B07FB61A68108700A75B9A /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; name = Info.plist; path = OnnxruntimeModuleExample/Info.plist; sourceTree = ""; }; 13B07FB71A68108700A75B9A /* main.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = main.m; path = OnnxruntimeModuleExample/main.m; sourceTree = ""; }; + 3ADD0A372EBB64D200761D6F /* ../src/test_types_float.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_float.ort; sourceTree = ""; }; + 3ADD0A382EBB64D200761D6F /* ../src/test_types_int8.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_int8.ort; sourceTree = ""; }; + 3ADD0A392EBB64D200761D6F /* ../src/test_types_int32.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_int32.ort; sourceTree = ""; }; + 
3ADD0A3A2EBB64D200761D6F /* ../src/test_types_int64.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_int64.ort; sourceTree = ""; }; + 3ADD0A3B2EBB64D200761D6F /* ../src/test_types_uint8.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_uint8.ort; sourceTree = ""; }; + 3ADD0A412EBB677300761D6F /* test_types_double.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_double.onnx; sourceTree = ""; }; + 3ADD0A432EBB679A00761D6F /* test_types_bool.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; path = ../src/test_types_bool.onnx; sourceTree = ""; }; 81AB9BB72411601600AC10FF /* LaunchScreen.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; name = LaunchScreen.storyboard; path = OnnxruntimeModuleExample/LaunchScreen.storyboard; sourceTree = ""; }; 9D58C0FCCF00905433F4ED74 /* Pods-OnnxruntimeModuleExample.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleExample.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleExample/Pods-OnnxruntimeModuleExample.debug.xcconfig"; sourceTree = ""; }; B70FCE6DFAB320E9051DA321 /* Pods-OnnxruntimeModuleExample.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleExample.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleExample/Pods-OnnxruntimeModuleExample.release.xcconfig"; sourceTree = ""; }; @@ -128,6 +142,13 @@ 83CBB9F61A601CBA00E9B192 = { isa = PBXGroup; children = ( + 3ADD0A432EBB679A00761D6F /* test_types_bool.onnx */, + 3ADD0A412EBB677300761D6F /* test_types_double.onnx */, + 3ADD0A372EBB64D200761D6F /* ../src/test_types_float.ort */, + 3ADD0A382EBB64D200761D6F /* ../src/test_types_int8.ort */, + 3ADD0A392EBB64D200761D6F /* ../src/test_types_int32.ort */, + 
3ADD0A3A2EBB64D200761D6F /* ../src/test_types_int64.ort */, + 3ADD0A3B2EBB64D200761D6F /* ../src/test_types_uint8.ort */, DBA8BA86267293C4008CC55A /* mnist.ort */, DBBF7413263B8CCB00487C77 /* 3.jpg */, 13B07FAE1A68108700A75B9A /* OnnxruntimeModuleExample */, @@ -247,6 +268,13 @@ DBA8BA87267293C4008CC55A /* mnist.ort in Resources */, DBBF7414263B8CCB00487C77 /* 3.jpg in Resources */, 81AB9BB82411601600AC10FF /* LaunchScreen.storyboard in Resources */, + 3ADD0A422EBB677300761D6F /* test_types_double.onnx in Resources */, + 3ADD0A3C2EBB64D200761D6F /* ../src/test_types_int8.ort in Resources */, + 3ADD0A442EBB679A00761D6F /* test_types_bool.onnx in Resources */, + 3ADD0A3D2EBB64D200761D6F /* ../src/test_types_int64.ort in Resources */, + 3ADD0A3E2EBB64D200761D6F /* ../src/test_types_int32.ort in Resources */, + 3ADD0A3F2EBB64D200761D6F /* ../src/test_types_float.ort in Resources */, + 3ADD0A402EBB64D200761D6F /* ../src/test_types_uint8.ort in Resources */, E329E1162D3728940016B599 /* PrivacyInfo.xcprivacy in Resources */, 13B07FBF1A68108700A75B9A /* Images.xcassets in Resources */, ); diff --git a/js/react_native/e2e/ios/test_types_bool.ort b/js/react_native/e2e/ios/test_types_bool.ort new file mode 100644 index 0000000000000000000000000000000000000000..ee955dcc6fe54489ee8459847db0b04d6fcd11a1 GIT binary patch literal 1824 zcmZ`(O=}ZT6ur^0jkVUGg9HgBiy&E)A)+m%xbPbbeguu;!i8ZvZ3mK>FizUIQ5RB* zxadN}U*Ot>KcR@Y5OL)~x^^X97AzgldGlVK4D`a8d*6Nco_pWd*dlWGUhDR_lw@3r z7+KJw@C$lpMC8ZcB(Y!@)U3$VZj#vAJ_KC`ZURSvx%(a8_2$lRpBcEgYpf%u0!^E^2-E@km`h*F4;G;0carwC4Sg8n9ccDb)a&4CSjz@O6%6dt z)^qHC7C;K+1x(pFBC-p-K@sv%Y^mv7TkntspGXq{WyIA1j?+5Ghh4)Q$E-(v4G9? 
zy}cyiSyaLOT-z8HV%VcPW(GE6*%$Q_s>4>=>vJwvu8z1ojt7zYP`DhiktA zo0@B`0<<4$UlSU`0{5B{X#a-xbN^BqlL@uTv?kBArzX>Tcue=gHw7saaU*#**_-L( ztLA1OynF0p3%+z8PY$iu^dstBTk2h>aiyU+(;u#aya#Rp`YdD(-0X*Y&G%jhbbqk^ z3pv=65sohapL6*&_?T~MnjW+Hw5PFCd}(}@MXYi{=N;6C;mh(MMb(+nntR26TiB0# z{W#q_t|$8+W4?)>hRre8!9M*!rK!DXV!{Tt;ok(fZ;ycnz&by12boLZm@5gRC8zf+ z4kLHjX~$l`IsWXjai`#N!>-fy-6(t-F2+vaF0Z;$&&BaS*7lt+2nJ3s@>YCr)rq6F zANau%Y5Ar~sLXRT{PeRI{+EVoV7#8=XTOF&%geWK&Wv3AZ<@GnN922s=M3hDYB}DZ j=S6@1NyAwNn2P literal 0 HcmV?d00001 diff --git a/js/react_native/e2e/ios/test_types_double.ort b/js/react_native/e2e/ios/test_types_double.ort new file mode 100644 index 0000000000000000000000000000000000000000..0259d0eae66abd210a2f8f9eb0b2f73b5b2dad6b GIT binary patch literal 1824 zcmZ`(O=}cE5UtfQSvN5*I!KTu>|v2Tgdsr_Lh#@>2K-ob6%QUzcJ zWeC}Z@l{f=ke4|&5L3ZOowx|p0ooW#Thk8~G05*Et!oG4QOtKSvYw*e2Y16U(JkdIxK91J_ zk&^1Eh5RMV*kgodWJ0E75|uJvy5Dgyr>A1c1Z-0pHH?=4@-~4sZ~~Y^?bN$~5tvjB ztiyn7)SsYE*072370?IT0QYmX@!)#jb9=EL28}3O^%uPeW+9(CVoq|xwYMOy0H2^g z0?5C=pCmksD!89(9n(S#eOUXh5;&?;PKjLs`=D8ZV0~`6i!1yi0x0T^r&+)y( zoINjMcfSkU|kRl6RA} znL55|Zq~uO$C|cbOZ#|gV85mwQSaJP?>hA>5BQn-a24b|a0}39A?x5~J?u5#dmYgA zLHjS{U`<9izW{vB#n)hCyvb>5%*IonL8tiA_$rIo<(Q5;tPaDMG*CGJ`I;*CvaERT&d^$_#bQeP8bAzryF^zzPIMY zQOghfV41XhQ)Oi4IU0W2SquM51356?$nmpY!=L5l+c#%MF8()-U#BhdJ;!qv>jSwQ luiy0|zvBgQ3;sn^nci}dlklaG*9g!1^1?!}6txb{@*l#I{}liL literal 0 HcmV?d00001 diff --git a/js/react_native/e2e/metro.config.js b/js/react_native/e2e/metro.config.js index 9f279f35616a3..e9ef3a02f075a 100644 --- a/js/react_native/e2e/metro.config.js +++ b/js/react_native/e2e/metro.config.js @@ -12,6 +12,7 @@ const config = { ], resolver: { sourceExts: ['tsx', 'ts', 'jsx', 'js', 'json'], // Ensure TypeScript files are recognized + assetExts: ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff', 'ico', 'webp', 'svg', 'ort', 'onnx'], }, }; module.exports = mergeConfig(getDefaultConfig(__dirname), config); diff --git a/js/react_native/e2e/src/App.tsx b/js/react_native/e2e/src/App.tsx index 
39c496062f665..438073f864c42 100644 --- a/js/react_native/e2e/src/App.tsx +++ b/js/react_native/e2e/src/App.tsx @@ -2,22 +2,65 @@ // Licensed under the MIT License. import * as React from 'react'; -import { Image, Text, TextInput, View } from 'react-native'; -// onnxruntime-react-native package is installed when bootstraping -import { InferenceSession, Tensor } from 'onnxruntime-react-native'; -import MNIST, { MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler'; -import { Buffer } from 'buffer'; -import { readFile } from 'react-native-fs'; +import { Button, SafeAreaView, ScrollView, StyleSheet, Text, View } from 'react-native'; +import MNISTTest from './MNISTTest'; +import BasicTypesTest from './BasicTypesTest'; + +type Page = 'home' | 'mnist' | 'basic-types'; interface State { - session: - InferenceSession | null; - output: - string | null; - imagePath: - string | null; + currentPage: Page; } +const styles = StyleSheet.create({ + container: { + flex: 1, + backgroundColor: '#f5f5f5', + }, + scrollContent: { + padding: 20, + alignItems: 'center', + }, + title: { + fontSize: 28, + fontWeight: 'bold', + marginTop: 20, + marginBottom: 10, + color: '#333', + textAlign: 'center', + }, + subtitle: { + fontSize: 18, + marginBottom: 30, + color: '#666', + textAlign: 'center', + }, + buttonContainer: { + width: '100%', + marginBottom: 30, + alignItems: 'center', + }, + buttonWrapper: { + width: '80%', + marginBottom: 10, + }, + description: { + fontSize: 14, + color: '#888', + textAlign: 'center', + paddingHorizontal: 20, + }, + header: { + padding: 10, + backgroundColor: '#fff', + borderBottomWidth: 1, + borderBottomColor: '#ddd', + }, + testContent: { + flex: 1, + }, +}); + // eslint-disable-next-line @typescript-eslint/no-empty-object-type export default class App extends React.PureComponent<{}, State> { // eslint-disable-next-line @typescript-eslint/no-empty-object-type @@ -25,104 +68,89 @@ export default class App extends React.PureComponent<{}, State> { 
super(props); this.state = { - session: null, - output: null, - imagePath: null, + currentPage: 'home', }; } - // Load a model when an app is loading - async componentDidMount(): Promise { - if (!this.state.session) { - try { - const imagePath = await MNIST.getImagePath(); - this.setState({ imagePath }); - - const modelPath = await MNIST.getLocalModelPath(); - - // test creating session with path - console.log('Creating with path'); - const pathSession: InferenceSession = await InferenceSession.create(modelPath); - void pathSession.release(); - - // and with bytes - console.log('Creating with bytes'); - const base64Str = await readFile(modelPath, 'base64'); - const bytes = Buffer.from(base64Str, 'base64'); - const session: InferenceSession = await InferenceSession.create(bytes); - this.setState({ session }); - - console.log('Test session created'); - void await this.infer(); - } catch (err) { - console.log(err.message); - } - } + navigateTo = (page: Page) => { + this.setState({ currentPage: page }); + }; + + renderHome(): React.JSX.Element { + return ( + + + ONNX Runtime E2E Tests + Select a test to run: + + + +