-
Notifications
You must be signed in to change notification settings - Fork 34
/
lce_minimal.cc
62 lines (48 loc) · 2.05 KB
/
lce_minimal.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <cstdio>
#include <cstdlib>
#include <memory>

#include "larq_compute_engine/tflite/kernels/lce_ops_register.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
// This file is based on the TensorFlow Lite minimal example, with the
// "BuiltinOpResolver" extended to include the "Larq Compute Engine" custom
// ops. It reads a binary model from disk and performs inference using the
// C++ interface. See the BUILD file in this directory for an example of
// linking the "Larq Compute Engine" custom ops into your inference binary.
// File-wide using-directive: names from the tflite namespace may be written
// unqualified in the rest of this file.
using namespace tflite;
// Terminates the program when the condition `x` evaluates to false: prints the
// source location of the failed check to stderr and exits with status 1.
// The expansion is wrapped in do { ... } while (0) so that the macro behaves
// as a single statement; this keeps it safe to use as the body of an unbraced
// if/else. Invoke it with a trailing semicolon.
#define TFLITE_MINIMAL_CHECK(x)                                     \
  do {                                                              \
    if (!(x)) {                                                     \
      std::fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
      std::exit(1);                                                 \
    }                                                               \
  } while (0)
int main(int argc, char* argv[]) {
if (argc != 2) {
fprintf(stderr, "lce_minimal <tflite model>\n");
return 1;
}
const char* filename = argv[1];
// Load model
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(filename);
TFLITE_MINIMAL_CHECK(model != nullptr);
// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
compute_engine::tflite::RegisterLCECustomOps(&resolver);
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> interpreter;
builder(&interpreter);
TFLITE_MINIMAL_CHECK(interpreter != nullptr);
// Allocate tensor buffers.
TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
printf("=== Pre-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(interpreter.get());
// Fill input buffers
// TODO(user): Insert code to fill input tensors
// Run inference
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
printf("\n\n=== Post-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(interpreter.get());
// Read output buffers
// TODO(user): Insert getting data out code.
return 0;
}