Skip to content

Commit

Permalink
fix: Validate input buffer type (#27)
Browse the repository at this point in the history
* fix: Also validate input buffer type

* fix: Use efficientdet
  • Loading branch information
mrousavy committed Jan 24, 2024
1 parent 99ccb41 commit e099d23
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 9 deletions.
40 changes: 38 additions & 2 deletions cpp/TensorHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,34 @@ std::string dataTypeToString(TfLiteType dataType) {
case kTfLiteInt4:
return "int4";
default:
[[unlikely]];
return "invalid";
}
}

/// Maps a JSI TypedArray kind to the TensorFlow Lite tensor data type it can
/// back. Used to validate that a JS input buffer matches the data type the
/// model's input tensor expects before copying bytes into it.
/// @param kind The kind of the incoming JS TypedArray.
/// @return The equivalent TfLiteType.
/// @throws std::runtime_error if the kind has no TFLite equivalent
///         (e.g. BigInt64Array / BigUint64Array, if the enum carries them).
TfLiteType getTFLDataTypeForTypedArrayKind(TypedArrayKind kind) {
  switch (kind) {
    case TypedArrayKind::Int8Array:
      return kTfLiteInt8;
    case TypedArrayKind::Int16Array:
      return kTfLiteInt16;
    case TypedArrayKind::Int32Array:
      return kTfLiteInt32;
    case TypedArrayKind::Uint8Array:
      return kTfLiteUInt8;
    // Clamping only affects JS-side writes; the underlying storage is still
    // plain uint8, so it maps to the same TFLite type.
    case TypedArrayKind::Uint8ClampedArray:
      return kTfLiteUInt8;
    case TypedArrayKind::Uint16Array:
      return kTfLiteUInt16;
    case TypedArrayKind::Uint32Array:
      return kTfLiteUInt32;
    case TypedArrayKind::Float32Array:
      return kTfLiteFloat32;
    case TypedArrayKind::Float64Array:
      return kTfLiteFloat64;
    default:
      // Without this, control could fall off the end of a non-void function
      // (undefined behavior) for any enum value not listed above.
      [[unlikely]];
      throw std::runtime_error("Unsupported TypedArray kind - cannot map to a TFLite data type!");
  }
}

size_t TensorHelpers::getTFLTensorDataTypeSize(TfLiteType dataType) {
switch (dataType) {
case kTfLiteFloat32:
Expand Down Expand Up @@ -195,8 +219,20 @@ void TensorHelpers::updateJSBufferFromTensor(jsi::Runtime& runtime, TypedArrayBa

void TensorHelpers::updateTensorFromJSBuffer(jsi::Runtime& runtime, TfLiteTensor* tensor,
TypedArrayBase& jsBuffer) {
auto name = std::string(TfLiteTensorName(tensor));
auto buffer = jsBuffer.getBuffer(runtime);
#if DEBUG
TypedArrayKind kind = jsBuffer.getKind(runtime);
TfLiteType receivedType = getTFLDataTypeForTypedArrayKind(kind);
TfLiteType expectedType = TfLiteTensorType(tensor);
if (receivedType != expectedType) {
[[unlikely]];
throw std::runtime_error("Invalid input type! Model expected " +
dataTypeToString(expectedType) + ", but received " +
dataTypeToString(receivedType) + "!");
}
#endif

std::string name = TfLiteTensorName(tensor);
jsi::ArrayBuffer buffer = jsBuffer.getBuffer(runtime);

#if DEBUG
int inputBufferSize = buffer.size(runtime);
Expand Down
14 changes: 11 additions & 3 deletions cpp/TensorflowPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,15 @@ TensorflowPlugin::getOutputArrayForTensor(jsi::Runtime& runtime, const TfLiteTen

void TensorflowPlugin::copyInputBuffers(jsi::Runtime& runtime, jsi::Object inputValues) {
// Input has to be array in input tensor size
auto array = inputValues.asArray(runtime);
#if DEBUG
if (!inputValues.isArray(runtime)) {
[[unlikely]];
throw std::runtime_error(
"TFLite: Input Values must be an array, one item for each input tensor!");
}
#endif

jsi::Array array = inputValues.asArray(runtime);
size_t count = array.size(runtime);
if (count != TfLiteInterpreterGetInputTensorCount(_interpreter)) {
[[unlikely]];
Expand All @@ -212,8 +220,8 @@ void TensorflowPlugin::copyInputBuffers(jsi::Runtime& runtime, jsi::Object input

for (size_t i = 0; i < count; i++) {
TfLiteTensor* tensor = TfLiteInterpreterGetInputTensor(_interpreter, i);
auto value = array.getValueAtIndex(runtime, i);
auto inputBuffer = getTypedArray(runtime, value.asObject(runtime));
jsi::Value value = array.getValueAtIndex(runtime, i);
TypedArrayBase inputBuffer = getTypedArray(runtime, value.asObject(runtime));
TensorHelpers::updateTensorFromJSBuffer(runtime, tensor, inputBuffer);
}
}
Expand Down
Binary file added example/assets/efficientdet.tflite
Binary file not shown.
Binary file removed example/assets/object_detector.tflite
Binary file not shown.
Binary file removed example/assets/smartreply_1_default_1.tflite
Binary file not shown.
10 changes: 6 additions & 4 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ export default function App(): React.ReactNode {
const { hasPermission, requestPermission } = useCameraPermission()
const device = useCameraDevice('back')

const model = useTensorflowModel(require('../assets/object_detector.tflite'))
// from https://www.kaggle.com/models/tensorflow/efficientdet/frameworks/tfLite
const model = useTensorflowModel(require('../assets/efficientdet.tflite'))
const actualModel = model.state === 'loaded' ? model.model : undefined

React.useEffect(() => {
Expand All @@ -51,14 +52,15 @@ export default function App(): React.ReactNode {
console.log(`Running inference on ${frame}`)
const resized = resize(frame, {
size: {
width: 640,
height: 640,
width: 320,
height: 320,
},
pixelFormat: 'rgb-uint8',
})
const typedArray = new Uint8Array(resized)
const result = actualModel.runSync([typedArray])
console.log('Result: ' + result.length)
const num_detections = result[3]?.[0] ?? 0
console.log('Result: ' + num_detections)
},
[actualModel]
)
Expand Down

0 comments on commit e099d23

Please sign in to comment.