Permalink
Cannot retrieve contributors at this time
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
62 lines (48 sloc)
2.05 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdio>
#include <cstdlib>
#include <memory>

#include "larq_compute_engine/tflite/kernels/lce_ops_register.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
// This file is based on the TF Lite minimal example, with the
// "BuiltinOpResolver" extended to include the "Larq Compute Engine" custom
// ops. Here we read a binary model from disk and perform inference using the
// C++ interface. See the BUILD file in this directory for an example of
// linking the "Larq Compute Engine" custom ops into your inference binary.
using namespace tflite; | |
// Terminates the process with exit status 1 if `x` evaluates to false,
// reporting the source file and line of the failing check on stderr.
//
// The body is wrapped in `do { ... } while (0)` so that
// `TFLITE_MINIMAL_CHECK(cond);` is exactly one statement: it can be used as
// the body of an unbraced `if` that has an `else` without orphaning the
// `else`, and the trailing semicolon is required, like a function call.
#define TFLITE_MINIMAL_CHECK(x)                                \
  do {                                                         \
    if (!(x)) {                                                \
      fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
      std::exit(1);                                            \
    }                                                          \
  } while (0)
int main(int argc, char* argv[]) { | |
if (argc != 2) { | |
fprintf(stderr, "lce_minimal <tflite model>\n"); | |
return 1; | |
} | |
const char* filename = argv[1]; | |
// Load model | |
std::unique_ptr<tflite::FlatBufferModel> model = | |
tflite::FlatBufferModel::BuildFromFile(filename); | |
TFLITE_MINIMAL_CHECK(model != nullptr); | |
// Build the interpreter | |
tflite::ops::builtin::BuiltinOpResolver resolver; | |
compute_engine::tflite::RegisterLCECustomOps(&resolver); | |
InterpreterBuilder builder(*model, resolver); | |
std::unique_ptr<Interpreter> interpreter; | |
builder(&interpreter); | |
TFLITE_MINIMAL_CHECK(interpreter != nullptr); | |
// Allocate tensor buffers. | |
TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); | |
printf("=== Pre-invoke Interpreter State ===\n"); | |
tflite::PrintInterpreterState(interpreter.get()); | |
// Fill input buffers | |
// TODO(user): Insert code to fill input tensors | |
// Run inference | |
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); | |
printf("\n\n=== Post-invoke Interpreter State ===\n"); | |
tflite::PrintInterpreterState(interpreter.get()); | |
// Read output buffers | |
// TODO(user): Insert getting data out code. | |
return 0; | |
} |