fix: Throw better errors
mrousavy committed Jan 19, 2024
1 parent ac3daf7 · commit 2e27f13
Showing 1 changed file with 84 additions and 54 deletions.
cpp/TensorflowPlugin.cpp
@@ -65,67 +65,74 @@ void TensorflowPlugin::installToRuntime(jsi::Runtime& runtime,
   Promise::createPromise(runtime, [=, &runtime](std::shared_ptr<Promise> promise) {
     // Launch async thread
     std::async(std::launch::async, [=, &runtime]() {
-      // Fetch model from URL (JS bundle)
-      Buffer buffer = fetchURL(modelPath);
-
-      // Load Model into Tensorflow
-      auto model = TfLiteModelCreate(buffer.data, buffer.size);
-      if (model == nullptr) {
-        callInvoker->invokeAsync([=]() {
-          promise->reject("Failed to load model from \"" + modelPath + "\"!");
-        });
-        return;
-      }
-
-      // Create TensorFlow Interpreter
-      auto options = TfLiteInterpreterOptionsCreate();
-
-      switch (delegateType) {
-        case Delegate::CoreML: {
+      try {
+        // Fetch model from URL (JS bundle)
+        Buffer buffer = fetchURL(modelPath);
+
+        // Load Model into Tensorflow
+        auto model = TfLiteModelCreate(buffer.data, buffer.size);
+        if (model == nullptr) {
+          callInvoker->invokeAsync([=]() {
+            promise->reject("Failed to load model from \"" + modelPath + "\"!");
+          });
+          return;
+        }
+
+        // Create TensorFlow Interpreter
+        auto options = TfLiteInterpreterOptionsCreate();
+
+        switch (delegateType) {
+          case Delegate::CoreML: {
 #if FAST_TFLITE_ENABLE_CORE_ML
-          TfLiteCoreMlDelegateOptions delegateOptions;
-          auto delegate = TfLiteCoreMlDelegateCreate(&delegateOptions);
-          TfLiteInterpreterOptionsAddDelegate(options, delegate);
-          break;
+            TfLiteCoreMlDelegateOptions delegateOptions;
+            auto delegate = TfLiteCoreMlDelegateCreate(&delegateOptions);
+            TfLiteInterpreterOptionsAddDelegate(options, delegate);
+            break;
 #else
-          callInvoker->invokeAsync([=]() {
-            promise->reject("CoreML Delegate is not enabled! Set $EnableCoreMLDelegate to true in Podfile and rebuild.");
-          });
-          return;
+            callInvoker->invokeAsync([=]() {
+              promise->reject("CoreML Delegate is not enabled! Set $EnableCoreMLDelegate to true in Podfile and rebuild.");
+            });
+            return;
 #endif
-        }
-        case Delegate::Metal: {
-          callInvoker->invokeAsync(
-              [=]() { promise->reject("Metal Delegate is not supported!"); });
-          return;
-        }
-        default: {
-          // use default CPU delegate.
-        }
-      }
-
-      auto interpreter = TfLiteInterpreterCreate(model, options);
-
-      if (interpreter == nullptr) {
-        callInvoker->invokeAsync([=]() {
-          promise->reject("Failed to create TFLite interpreter from model \"" +
-                          modelPath + "\"!");
-        });
-        return;
-      }
-
-      // Initialize Model and allocate memory buffers
-      auto plugin = std::make_shared<TensorflowPlugin>(interpreter, buffer, delegateType,
-                                                       callInvoker);
-
-      callInvoker->invokeAsync([=, &runtime]() {
-        auto result = jsi::Object::createFromHostObject(runtime, plugin);
-        promise->resolve(std::move(result));
-      });
-
-      auto end = std::chrono::steady_clock::now();
-      log("Successfully loaded Tensorflow Model in %i ms!",
-          std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
+          }
+          case Delegate::Metal: {
+            callInvoker->invokeAsync(
+                [=]() { promise->reject("Metal Delegate is not supported!"); });
+            return;
+          }
+          default: {
+            // use default CPU delegate.
+          }
+        }
+
+        auto interpreter = TfLiteInterpreterCreate(model, options);
+
+        if (interpreter == nullptr) {
+          callInvoker->invokeAsync([=]() {
+            promise->reject("Failed to create TFLite interpreter from model \"" +
+                            modelPath + "\"!");
+          });
+          return;
+        }
+
+        // Initialize Model and allocate memory buffers
+        auto plugin = std::make_shared<TensorflowPlugin>(interpreter, buffer, delegateType,
+                                                         callInvoker);
+
+        callInvoker->invokeAsync([=, &runtime]() {
+          auto result = jsi::Object::createFromHostObject(runtime, plugin);
+          promise->resolve(std::move(result));
+        });
+
+        auto end = std::chrono::steady_clock::now();
+        log("Successfully loaded Tensorflow Model in %i ms!",
+            std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
+      } catch (std::exception& error) {
+        std::string message = error.what();
+        callInvoker->invokeAsync([=]() {
+          promise->reject(message);
+        });
+      }
     });
   });
   return promise;
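
Note on this hunk: the exception is caught on the background thread, but the Promise may only be rejected from the JS thread, which is why the catch block copies error.what() into a std::string before the invokeAsync hop; the exception object is destroyed when the catch block exits, so the JS-thread lambda must not reference it. A minimal self-contained sketch of that pattern, using stand-ins for React Native's CallInvoker and Promise types (assumptions for illustration, not the real headers):

#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// Stand-in for react::CallInvoker: the real one enqueues the lambda on the
// JS thread; this one just runs it inline.
struct CallInvoker {
  void invokeAsync(std::function<void()> fn) { fn(); }
};

// Stand-in for the JSI Promise wrapper.
struct Promise {
  void reject(const std::string& message) {
    std::cout << "rejected: " << message << "\n";
  }
};

void loadModelAsync(std::shared_ptr<CallInvoker> callInvoker,
                    std::shared_ptr<Promise> promise) {
  std::async(std::launch::async, [=]() {
    try {
      throw std::runtime_error("Failed to load model!"); // any step may throw
    } catch (std::exception& error) {
      // Copy the message now: `error` dies with this catch block, so the
      // lambda below captures a std::string by value, never the exception.
      std::string message = error.what();
      callInvoker->invokeAsync([=]() { promise->reject(message); });
    }
  }).wait();
}

int main() {
  loadModelAsync(std::make_shared<CallInvoker>(), std::make_shared<Promise>());
}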
@@ -134,6 +141,29 @@ void TensorflowPlugin::installToRuntime(jsi::Runtime& runtime,
   runtime.global().setProperty(runtime, "__loadTensorflowModel", func);
 }
 
+std::string tfLiteStatusToString(TfLiteStatus status) {
+  switch (status) {
+    case kTfLiteOk:
+      return "ok";
+    case kTfLiteError:
+      return "error";
+    case kTfLiteDelegateError:
+      return "delegate-error";
+    case kTfLiteApplicationError:
+      return "application-error";
+    case kTfLiteDelegateDataNotFound:
+      return "delegate-data-not-found";
+    case kTfLiteDelegateDataWriteError:
+      return "delegate-data-write-error";
+    case kTfLiteDelegateDataReadError:
+      return "delegate-data-read-error";
+    case kTfLiteUnresolvedOps:
+      return "unresolved-ops";
+    case kTfLiteCancelled:
+      return "cancelled";
+  }
+}
+
 TensorflowPlugin::TensorflowPlugin(TfLiteInterpreter* interpreter, Buffer model, Delegate delegate,
                                    std::shared_ptr<react::CallInvoker> callInvoker)
     : _interpreter(interpreter), _delegate(delegate), _model(model), _callInvoker(callInvoker) {
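
A note on the shape of the new helper: listing every enumerator with no default: label lets compilers warn (e.g. GCC/Clang's -Wswitch) if a future TensorFlow Lite release adds a status value this switch does not handle. A toy sketch of the same pattern, with a stand-in enum (not the real TfLiteStatus) and a defensive fallback return, which the committed helper omits by relying on the switch being exhaustive:

#include <string>

// Toy stand-in for TfLiteStatus; the real enum lives in TensorFlow Lite's
// C API headers.
enum class Status { Ok, Error, Cancelled };

std::string statusToString(Status status) {
  // No `default:` on purpose: with -Wswitch, the compiler flags any
  // enumerator missing from this switch.
  switch (status) {
    case Status::Ok:
      return "ok";
    case Status::Error:
      return "error";
    case Status::Cancelled:
      return "cancelled";
  }
  // Fallback so control cannot fall off the end of a non-void function if
  // an out-of-range value sneaks in.
  return "unknown";
}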
@@ -142,7 +172,7 @@ TensorflowPlugin::TensorflowPlugin(TfLiteInterpreter* interpreter, Buffer model,
   if (status != kTfLiteOk) {
     [[unlikely]];
     throw std::runtime_error("Failed to allocate memory for input/output tensors! Status: " +
-                             std::to_string(status));
+                             tfLiteStatusToString(status));
   }
 
   log("Successfully created Tensorflow Plugin!");
@@ -206,7 +236,7 @@ void TensorflowPlugin::run() {
   TfLiteStatus status = TfLiteInterpreterInvoke(_interpreter);
   if (status != kTfLiteOk) {
     [[unlikely]];
-    throw std::runtime_error("Failed to run TFLite Model! Status: " + std::to_string(status));
+    throw std::runtime_error("Failed to run TFLite Model! Status: " + tfLiteStatusToString(status));
   }
 }
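
Both throw sites sit behind a bare [[unlikely]]; statement, a C++20 attribute on a null statement that marks the surrounding error path as cold. Attaching the attribute to the branch itself is the more common spelling of the same hint; a small sketch under that assumption (the enum and statusToString here are stand-ins, not the plugin's real types):

#include <stdexcept>
#include <string>

// Stand-ins for illustration only.
enum Status { Ok, Error };
std::string statusToString(Status s) { return s == Ok ? "ok" : "error"; }

void invoke(Status status) {
  // Annotating the branch directly; a bare `[[unlikely]];` inside the block,
  // as in the committed code, marks the same execution path as cold.
  if (status != Ok) [[unlikely]] {
    throw std::runtime_error("Failed to run TFLite Model! Status: " +
                             statusToString(status));
  }
}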
