Skip to content

Commit

Permalink
feat: Add CoreML delegate support
Browse files Browse the repository at this point in the history
  • Loading branch information
mrousavy committed Aug 21, 2023
1 parent e3af880 commit bb2a4d9
Show file tree
Hide file tree
Showing 14 changed files with 186 additions and 39 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ import { multiply } from 'vision-camera-tflite';
const result = await multiply(3, 7);
```

### Using GPU Delegates

GPU delegates offer faster, GPU-accelerated computation. There are multiple GPU delegates you can enable:

#### CoreML (iOS)

To enable the CoreML Delegate, you need to include the CoreML/Metal code in your project:

1. Set `$EnableCoreMLDelegate` to true in your `Podfile`:
```ruby
$EnableCoreMLDelegate=true

# rest of your podfile...
```
2. Open your iOS project in Xcode and add the `CoreML` framework to your project:
![Xcode > xcodeproj > General > Frameworks, Libraries and Embedded Content > CoreML](img/ios-coreml-guide.png)


## Contributing

See the [contributing guide](CONTRIBUTING.md) to learn how to contribute to the repository and the development workflow.
Expand Down
85 changes: 54 additions & 31 deletions cpp/TensorflowPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
#include <iostream>
#include <future>

#if VISION_CAMERA_TFLITE_ENABLE_CORE_ML
#include <TensorFlowLiteCCoreML/TensorFlowLiteCCoreML.h>
#endif

using namespace facebook;
using namespace mrousavy;

Expand All @@ -28,8 +32,8 @@ void log(std::string string...) {
void TensorflowPlugin::installToRuntime(jsi::Runtime& runtime,
std::shared_ptr<react::CallInvoker> callInvoker,
FetchURLFunc fetchURL) {


auto func = jsi::Function::createFromHostFunction(runtime,
jsi::PropNameID::forAscii(runtime, "__loadTensorflowModel"),
1,
Expand All @@ -39,38 +43,30 @@ void TensorflowPlugin::installToRuntime(jsi::Runtime& runtime,
size_t count) -> jsi::Value {
auto start = std::chrono::steady_clock::now();
auto modelPath = arguments[0].asString(runtime).utf8(runtime);

log("Loading TensorFlow Lite Model from \"%s\"...", modelPath.c_str());

// TODO: Figure out how to use Metal/CoreML delegates
Delegate delegate = Delegate::Default;
/*
auto delegates = [[NSMutableArray alloc] init];
Delegate delegateType = Delegate::Default;
if (count > 1 && arguments[1].isString()) {
// user passed a custom delegate command
auto delegate = arguments[1].asString(runtime).utf8(runtime);
if (delegate == "core-ml") {
NSLog(@"Using CoreML delegate.");
[delegates addObject:[[TFLCoreMLDelegate alloc] init]];
delegate = Delegate::CoreML;
delegateType = Delegate::CoreML;
} else if (delegate == "metal") {
NSLog(@"Using Metal delegate.");
[delegates addObject:[[TFLMetalDelegate alloc] init]];
delegate = Delegate::Metal;
delegateType = Delegate::Metal;
} else {
NSLog(@"Using standard CPU delegate.");
delegate = Delegate::Default;
delegateType = Delegate::Default;
}
}
*/


auto promise = Promise::createPromise(runtime,
[=, &runtime](std::shared_ptr<Promise> promise) {
// Launch async thread
std::async(std::launch::async, [=, &runtime]() {
// Fetch model from URL (JS bundle)
Buffer buffer = fetchURL(modelPath);

// Load Model into Tensorflow
auto model = TfLiteModelCreate(buffer.data, buffer.size);
if (model == nullptr) {
Expand All @@ -79,33 +75,60 @@ void TensorflowPlugin::installToRuntime(jsi::Runtime& runtime,
});
return;
}

// Create TensorFlow Interpreter
auto options = TfLiteInterpreterOptionsCreate();

switch (delegateType) {
case Delegate::CoreML: {
#if VISION_CAMERA_TFLITE_ENABLE_CORE_ML
TfLiteCoreMlDelegateOptions delegateOptions;
auto delegate = TfLiteCoreMlDelegateCreate(&delegateOptions);
TfLiteInterpreterOptionsAddDelegate(options, delegate);
break;
#else
callInvoker->invokeAsync([=]() {
promise->reject("CoreML Delegate is not enabled! Set $EnableCoreMLDelegate to true in Podfile and rebuild.");
});
return;
#endif
}
case Delegate::Metal: {
callInvoker->invokeAsync([=]() {
promise->reject("Metal Delegate is not supported!");
});
return;
}
default: {
// use default CPU delegate.
}
}

auto interpreter = TfLiteInterpreterCreate(model, options);

if (interpreter == nullptr) {
callInvoker->invokeAsync([=]() {
promise->reject("Failed to create TFLite interpreter from model \"" + modelPath + "\"!");
});
return;
}

// Initialize Model and allocate memory buffers
auto plugin = std::make_shared<TensorflowPlugin>(interpreter, buffer, delegate, callInvoker);
auto plugin = std::make_shared<TensorflowPlugin>(interpreter, buffer, delegateType, callInvoker);

callInvoker->invokeAsync([=, &runtime]() {
auto result = jsi::Object::createFromHostObject(runtime, plugin);
promise->resolve(std::move(result));
});

auto end = std::chrono::steady_clock::now();
log("Successfully loaded Tensorflow Model in %i ms!",
std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
});
});
return promise;
});

runtime.global().setProperty(runtime, "__loadTensorflowModel", func);
}

Expand All @@ -120,7 +143,7 @@ TensorflowPlugin::TensorflowPlugin(TfLiteInterpreter* interpreter,
if (status != kTfLiteOk) {
throw std::runtime_error("Failed to allocate memory for input/output tensors! Status: " + std::to_string(status));
}

log("Successfully created Tensorflow Plugin!");
}

Expand Down Expand Up @@ -151,7 +174,7 @@ void TensorflowPlugin::copyInputBuffers(jsi::Runtime &runtime, jsi::Object input
if (count != TfLiteInterpreterGetInputTensorCount(_interpreter)) {
throw std::runtime_error("TFLite: Input Values have different size than there are input tensors!");
}

for (size_t i = 0; i < count; i++) {
TfLiteTensor* tensor = TfLiteInterpreterGetInputTensor(_interpreter, i);
auto value = array.getValueAtIndex(runtime, i);
Expand Down Expand Up @@ -184,7 +207,7 @@ void TensorflowPlugin::run() {

jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& propNameId) {
auto propName = propNameId.utf8(runtime);

if (propName == "runSync") {
return jsi::Function::createFromHostFunction(runtime,
jsi::PropNameID::forAscii(runtime, "runModel"),
Expand Down Expand Up @@ -215,7 +238,7 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
// 2.
try {
this->run();

this->_callInvoker->invokeAsync([=, &runtime]() {
// 3.
auto result = this->copyOutputBuffers(runtime);
Expand All @@ -236,7 +259,7 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
if (tensor == nullptr) {
throw jsi::JSError(runtime, "Failed to get input tensor " + std::to_string(i) + "!");
}

jsi::Object object = TensorHelpers::tensorToJSObject(runtime, tensor);
tensors.setValueAtIndex(runtime, i, object);
}
Expand All @@ -249,7 +272,7 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
if (tensor == nullptr) {
throw jsi::JSError(runtime, "Failed to get output tensor " + std::to_string(i) + "!");
}

jsi::Object object = TensorHelpers::tensorToJSObject(runtime, tensor);
tensors.setValueAtIndex(runtime, i, object);
}
Expand All @@ -264,7 +287,7 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
return jsi::String::createFromUtf8(runtime, "metal");
}
}

return jsi::HostObject::get(runtime, propNameId);
}

Expand Down
2 changes: 2 additions & 0 deletions example/ios/Podfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ if linkage != nil
use_frameworks! :linkage => linkage.to_sym
end

$EnableCoreMLDelegate=true

target 'TfliteExample' do
config = use_native_modules!

Expand Down
4 changes: 2 additions & 2 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,10 @@ SPEC CHECKSUMS:
React-utils: 0a70ea97d4e2749f336b450c082905be1d389435
ReactCommon: e593d19c9e271a6da4d0bd7f13b28cfeae5d164b
SocketRocket: f32cd54efbe0f095c4d7594881e52619cfe80b17
vision-camera-tflite: df63762f0c16d98f4b2f80785e6b5b349572a7e2
vision-camera-tflite: 0966a28ed0d2e0bf927b91558bb39183a670f5be
Yoga: 65286bb6a07edce5e4fe8c90774da977ae8fc009
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

PODFILE CHECKSUM: 73276d9fcf292db20cba4f8b0a4ddba7d81803e2
PODFILE CHECKSUM: 9ac9e7462894d8a61cb346a87c5aeb9996b11140

COCOAPODS: 1.12.1
6 changes: 6 additions & 0 deletions example/ios/TfliteExample.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
13B07FC11A68108700A75B9A /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 13B07FB71A68108700A75B9A /* main.m */; };
7699B88040F8A987B510C191 /* libPods-TfliteExample-TfliteExampleTests.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 19F6CBCC0A4E27FBF8BF4A61 /* libPods-TfliteExample-TfliteExampleTests.a */; };
81AB9BB82411601600AC10FF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 81AB9BB72411601600AC10FF /* LaunchScreen.storyboard */; };
B84FBB7F2A93BC88008D281C /* CoreML.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = B84FBB7D2A93B607008D281C /* CoreML.framework */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand Down Expand Up @@ -43,6 +44,8 @@
5DCACB8F33CDC322A6C60F78 /* libPods-TfliteExample.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-TfliteExample.a"; sourceTree = BUILT_PRODUCTS_DIR; };
81AB9BB72411601600AC10FF /* LaunchScreen.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; name = LaunchScreen.storyboard; path = TfliteExample/LaunchScreen.storyboard; sourceTree = "<group>"; };
89C6BE57DB24E9ADA2F236DE /* Pods-TfliteExample-TfliteExampleTests.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-TfliteExample-TfliteExampleTests.release.xcconfig"; path = "Target Support Files/Pods-TfliteExample-TfliteExampleTests/Pods-TfliteExample-TfliteExampleTests.release.xcconfig"; sourceTree = "<group>"; };
B84FBB7B2A93B603008D281C /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; };
B84FBB7D2A93B607008D281C /* CoreML.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreML.framework; path = System/Library/Frameworks/CoreML.framework; sourceTree = SDKROOT; };
ED297162215061F000B7C4FE /* JavaScriptCore.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = JavaScriptCore.framework; path = System/Library/Frameworks/JavaScriptCore.framework; sourceTree = SDKROOT; };
/* End PBXFileReference section */

Expand All @@ -59,6 +62,7 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
B84FBB7F2A93BC88008D281C /* CoreML.framework in Frameworks */,
0C80B921A6F3F58F76C31292 /* libPods-TfliteExample.a in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand Down Expand Up @@ -99,6 +103,8 @@
2D16E6871FA4F8E400B85C8A /* Frameworks */ = {
isa = PBXGroup;
children = (
B84FBB7D2A93B607008D281C /* CoreML.framework */,
B84FBB7B2A93B603008D281C /* Metal.framework */,
ED297162215061F000B7C4FE /* JavaScriptCore.framework */,
5DCACB8F33CDC322A6C60F78 /* libPods-TfliteExample.a */,
19F6CBCC0A4E27FBF8BF4A61 /* libPods-TfliteExample-TfliteExampleTests.a */,
Expand Down
12 changes: 7 additions & 5 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@ import { StyleSheet, View, Text } from 'react-native';
import { loadTensorflowModel, useTensorflowModel } from 'vision-camera-tflite';

export default function App() {
const [result, setResult] = React.useState<number | undefined>();
const [result, setResult] = React.useState('');

const model = useTensorflowModel(
require('../assets/object_detection_mobile_object_localizer_v1_1_default_1.tflite')
require('../assets/object_detection_mobile_object_localizer_v1_1_default_1.tflite'),
'core-ml'
);

React.useEffect(() => {
if (model.model == null) return;

console.log(`Running Model...`);
const result = model.model.run([new Uint8Array([5])]);
result.then((result) => {
console.log(`Successfully ran Model!`, result);
const r = model.model.run([new Uint8Array([5])]);
r.then((output) => {
console.log(`Successfully ran Model!`, output);
setResult(`${output[0]}${output[1]}${output[2]}...`);
});
}, [model.model]);

Expand Down
Binary file added img/ios-coreml-guide.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#import <TensorFlowLiteCCoreML/coreml_delegate.h>
72 changes: 72 additions & 0 deletions ios/TensorFlowLiteCCoreML.framework/Headers/coreml_delegate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// C API for the TensorFlow Lite Core ML delegate: configuration options plus
// the create/delete entry points used to attach the delegate to an
// interpreter.
#ifndef TENSORFLOW_LITE_DELEGATES_COREML_COREML_DELEGATE_H_
#define TENSORFLOW_LITE_DELEGATES_COREML_COREML_DELEGATE_H_

#include "TensorFlowLiteC/common.h"

// LINT.IfChange

#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// Controls on which devices TfLiteCoreMlDelegateCreate will actually produce
// a delegate.
typedef enum {
// Create the Core ML delegate only on devices with an Apple Neural Engine;
// creation returns nullptr otherwise.
TfLiteCoreMlDelegateDevicesWithNeuralEngine,
// Always create the Core ML delegate, regardless of hardware support.
TfLiteCoreMlDelegateAllDevices
} TfLiteCoreMlDelegateEnabledDevices;

// Options for configuring the Core ML delegate at creation time.
typedef struct {
// Only create the delegate when a Neural Engine is available on the device
// (see TfLiteCoreMlDelegateEnabledDevices above).
TfLiteCoreMlDelegateEnabledDevices enabled_devices;
// Specifies the target Core ML version for model conversion.
// Core ML 3 comes with many more ops, but some ops (e.g. reshape) are not
// delegated due to an input-rank constraint.
// If not set to one of the valid versions, the delegate will use the
// highest version possible on the platform.
// Valid versions: (2, 3)
int coreml_version;
// This sets the maximum number of Core ML delegates created.
// Each graph corresponds to one delegated node subset in the
// TFLite model. Set this to 0 to delegate all possible partitions.
int max_delegated_partitions;
// This sets the minimum number of nodes per partition delegated with
// the Core ML delegate. Defaults to 2.
int min_nodes_per_partition;
#ifdef TFLITE_DEBUG_DELEGATE
// This sets the index of the first node that could be delegated
// (debug builds only).
int first_delegate_node_index;
// This sets the index of the last node that could be delegated
// (debug builds only).
int last_delegate_node_index;
#endif
} TfLiteCoreMlDelegateOptions;

// Returns a delegate that uses Core ML for ops execution, or nullptr when
// creation is not possible (see enabled_devices above).
// The returned delegate must outlive the interpreter it is attached to.
TfLiteDelegate* TfLiteCoreMlDelegateCreate(
const TfLiteCoreMlDelegateOptions* options);

// Does any needed cleanup and deletes 'delegate'.
void TfLiteCoreMlDelegateDelete(TfLiteDelegate* delegate);

#ifdef __cplusplus
}
#endif // __cplusplus

// LINT.ThenChange(README.md)

#endif // TENSORFLOW_LITE_DELEGATES_COREML_COREML_DELEGATE_H_
9 changes: 9 additions & 0 deletions ios/TensorFlowLiteCCoreML.framework/Modules/module.modulemap
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Clang module map for the vendored TensorFlowLiteCCoreML.framework.
// Exposes the Core ML delegate C API through the umbrella header and links
// the system libraries/frameworks the delegate depends on.
framework module TensorFlowLiteCCoreML {
umbrella header "TensorFlowLiteCCoreML.h"
export *
module * { export * }
link "m"
link "pthread"
link framework "CoreML"
link framework "Foundation"
}
Binary file not shown.

0 comments on commit bb2a4d9

Please sign in to comment.