Skip to content

Commit

Permalink
feat: support for android (#12)
Browse files Browse the repository at this point in the history
* feat: support for android

* chore: revert unnecessary changes
  • Loading branch information
jaroslawkrol committed Dec 21, 2023
1 parent ab68a41 commit 13db933
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 39 deletions.
23 changes: 11 additions & 12 deletions android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ set (CMAKE_CXX_STANDARD 17)
find_package(ReactAndroid REQUIRED CONFIG)
find_package(fbjni REQUIRED CONFIG)

# TODO: Fix linking TFLite
# find_library(
# TFLITE
# tensorflowlite_jni
# PATHS "./src/main/cpp/lib/tensorflow/jni/${ANDROID_ABI}"
# NO_DEFAULT_PATH
# NO_CMAKE_FIND_ROOT_PATH)
find_library(
TFLITE
tensorflowlite_jni
PATHS "./src/main/cpp/lib/tensorflow/jni/${ANDROID_ABI}"
NO_DEFAULT_PATH
NO_CMAKE_FIND_ROOT_PATH
)

string(APPEND CMAKE_CXX_FLAGS " -DANDROID")

Expand All @@ -23,9 +23,8 @@ add_library(
SHARED
../cpp/jsi/Promise.cpp
../cpp/jsi/TypedArray.cpp
# TODO: Uncomment this when tensorflow-lite C/C++ API can be successfully built/linked here
#../cpp/TensorflowPlugin.cpp
#../cpp/TensorHelpers.cpp
../cpp/TensorflowPlugin.cpp
../cpp/TensorHelpers.cpp
src/main/cpp/Tflite.cpp
)

Expand All @@ -35,7 +34,7 @@ target_include_directories(
PRIVATE
"../cpp"
"src/main/cpp"
# "src/main/cpp/lib/tensorflow/headers/"
"src/main/cpp/lib/tensorflow/headers/"
"${NODE_MODULES_DIR}/react-native/ReactCommon"
"${NODE_MODULES_DIR}/react-native/ReactCommon/callinvoker"
"${NODE_MODULES_DIR}/react-native/ReactAndroid/src/main/jni/react/turbomodule" # <-- CallInvokerHolder JNI wrapper
Expand All @@ -49,5 +48,5 @@ target_link_libraries(
ReactAndroid::jsi # <-- jsi.h
ReactAndroid::reactnativejni # <-- CallInvokerImpl
fbjni::fbjni # <-- fbjni.h
# ${TFLITE}
${TFLITE}
)
6 changes: 3 additions & 3 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ dependencies {
implementation "com.facebook.react:react-native:+"

// Tensorflow Lite .aar (includes C API via prefabs)
implementation "org.tensorflow:tensorflow-lite:2.13.0"
extractHeaders("org.tensorflow:tensorflow-lite:2.13.0")
extractSO("org.tensorflow:tensorflow-lite:2.13.0")
implementation "org.tensorflow:tensorflow-lite:2.12.0"
extractHeaders("org.tensorflow:tensorflow-lite:2.12.0")
extractSO("org.tensorflow:tensorflow-lite:2.12.0")
}

task extractAARHeaders {
Expand Down
51 changes: 30 additions & 21 deletions android/src/main/cpp/Tflite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
#include <jsi/jsi.h>
#include <memory>

// TODO: Uncomment this when tensorflow-lite C/C++ API can be successfully built/linked here
// #include "TensorflowPlugin.h"
#include "TensorflowPlugin.h"
#include <ReactCommon/CallInvoker.h>
#include <ReactCommon/CallInvokerHolder.h>

namespace mrousavy {

JavaVM *java_machine;

using namespace facebook;
using namespace facebook::jni;

// Java Insaller
// Java Installer
struct TfliteModule : public jni::JavaClass<TfliteModule> {
public:
static constexpr auto kJavaDescriptor = "Lcom/tflite/TfliteModule;";
Expand All @@ -29,37 +30,44 @@ struct TfliteModule : public jni::JavaClass<TfliteModule> {
}
auto jsCallInvoker = jsCallInvokerHolder->cthis()->getCallInvoker();

// TODO: Uncomment this when tensorflow-lite C/C++ API can be successfully built/linked here
/*auto fetchByteDataFromUrl = [](std::string url) {
auto fetchByteDataFromUrl = [](std::string url) {

// Attaching Current Thread to JVM
JNIEnv* env = nullptr;
int getEnvStat = java_machine->GetEnv((void**)&env, JNI_VERSION_1_6);
if (getEnvStat == JNI_EDETACHED) {
if (java_machine->AttachCurrentThread(&env, nullptr) != 0) {
throw std::runtime_error("Failed to attach thread to JVM");
}
}

static const auto cls = javaClassStatic();
static const auto method =
cls->getStaticMethod<jbyteArray(std::string)>("fetchByteDataFromUrl");
cls->getStaticMethod<jbyteArray(std::string)>("fetchByteDataFromUrl");

auto byteData = method(cls, url);

// TODO: to review by someone experienced much more in C++
// Detaching current thread causes app crash with exception:
// "Unable to retrieve jni environment. is the thread attached?"
// anyway, there is still a risk of memory leakage without calling the function below:

// java_machine->DetachCurrentThread();

auto size = byteData->size();
auto bytes = byteData->getRegion(0, size);
void* data = malloc(size);
memcpy(data, bytes.get(), size);

return Buffer {
.data = data,
.size = size
.data = data,
.size = size
};

};
*/

try {
// TODO: Uncomment this when tensorflow-lite C/C++ API can be successfully built/linked here
// TensorflowPlugin::installToRuntime(*runtime, jsCallInvoker, fetchByteDataFromUrl);

// TODO: Remove this when tensorflow-lite C/C++ API can be successfully built/linked here
auto func = jsi::Function::createFromHostFunction(
*runtime, jsi::PropNameID::forAscii(*runtime, "__loadTensorflowModel"), 1,
[=](jsi::Runtime& runtime, const jsi::Value& thisValue, const jsi::Value* arguments,
size_t count) -> jsi::Value {
throw jsi::JSError(runtime, "react-native-fast-tflite is not yet supported on Android! "
"I couldn't manage to get TFLite to build for NDK/C++ :/");
});
runtime->global().setProperty(*runtime, "__loadTensorflowModel", func);
TensorflowPlugin::installToRuntime(*runtime, jsCallInvoker, fetchByteDataFromUrl);
} catch (std::exception& exc) {
return false;
}
Expand All @@ -77,5 +85,6 @@ struct TfliteModule : public jni::JavaClass<TfliteModule> {
} // namespace mrousavy

JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
java_machine = vm;
return facebook::jni::initialize(vm, [] { mrousavy::TfliteModule::registerNatives(); });
}
5 changes: 2 additions & 3 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import * as React from 'react'

import { StyleSheet, View, Text } from 'react-native'
import { StyleSheet, View, Text, Platform } from 'react-native'
import {
loadTensorflowModel,
useTensorflowModel,
} from 'react-native-fast-tflite'

Expand All @@ -11,7 +10,7 @@ export default function App() {

const model = useTensorflowModel(
require('../assets/object_detection_mobile_object_localizer_v1_1_default_1.tflite'),
'core-ml'
Platform.OS === 'ios' ? 'core-ml' : 'default'
)

React.useEffect(() => {
Expand Down

0 comments on commit 13db933

Please sign in to comment.