diff --git a/larq_compute_engine/tflite/python/interpreter.py b/larq_compute_engine/tflite/python/interpreter.py index 84682c99..e508af93 100644 --- a/larq_compute_engine/tflite/python/interpreter.py +++ b/larq_compute_engine/tflite/python/interpreter.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Tuple, Union +from typing import Iterator, List, Tuple, Union, Optional import numpy as np from tqdm import tqdm @@ -78,6 +78,16 @@ def input_shapes(self) -> List[Tuple[int]]: """Returns a list of input shapes.""" return self.interpreter.input_shapes + @property + def input_scales(self) -> List[Optional[Union[float, List[float]]]]: + """Returns a list of input scales.""" + return self.interpreter.input_scales + + @property + def input_zero_points(self) -> List[Optional[int]]: + """Returns a list of input zero points.""" + return self.interpreter.input_zero_points + @property def output_types(self) -> list: """Returns a list of output types.""" @@ -88,6 +98,16 @@ def output_shapes(self) -> List[Tuple[int]]: """Returns a list of output shapes.""" return self.interpreter.output_shapes + @property + def output_scales(self) -> List[Optional[Union[float, List[float]]]]: + """Returns a list of output scales.""" + return self.interpreter.output_scales + + @property + def output_zero_points(self) -> List[Optional[int]]: + """Returns a list of output zero points.""" + return self.interpreter.output_zero_points + def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data: """Generates output predictions for the input samples. 
diff --git a/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc b/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc index bc7aaef3..7a67952c 100644 --- a/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc +++ b/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc @@ -60,5 +60,13 @@ PYBIND11_MODULE(interpreter_wrapper_lite, m) { nullptr) .def_property("output_shapes", &LiteInterpreterWrapper::get_output_shapes, nullptr) + .def_property("input_zero_points", + &LiteInterpreterWrapper::get_input_zero_points, nullptr) + .def_property("output_zero_points", + &LiteInterpreterWrapper::get_output_zero_points, nullptr) + .def_property("input_scales", &LiteInterpreterWrapper::get_input_scales, + nullptr) + .def_property("output_scales", &LiteInterpreterWrapper::get_output_scales, + nullptr) .def("predict", &LiteInterpreterWrapper::predict); }; diff --git a/larq_compute_engine/tflite/python/interpreter_wrapper_utils.h b/larq_compute_engine/tflite/python/interpreter_wrapper_utils.h index 28491ac4..a8fd5dc0 100644 --- a/larq_compute_engine/tflite/python/interpreter_wrapper_utils.h +++ b/larq_compute_engine/tflite/python/interpreter_wrapper_utils.h @@ -35,30 +35,62 @@ class InterpreterWrapperBase { pybind11::list predict(const pybind11::list& input_list); // List of numpy types - pybind11::list get_input_types() const { + pybind11::list get_input_types() { MINIMAL_CHECK(interpreter_); return get_types(interpreter_->inputs()); } - pybind11::list get_output_types() const { + pybind11::list get_output_types() { MINIMAL_CHECK(interpreter_); return get_types(interpreter_->outputs()); } // List of shape tuples - pybind11::list get_input_shapes() const { + pybind11::list get_input_shapes() { MINIMAL_CHECK(interpreter_); return get_shapes(interpreter_->inputs()); } - pybind11::list get_output_shapes() const { + pybind11::list get_output_shapes() { MINIMAL_CHECK(interpreter_); return get_shapes(interpreter_->outputs()); } + // List of zero 
points, None for non-quantized tensors + pybind11::list get_input_zero_points() { + MINIMAL_CHECK(interpreter_); + return get_zero_points(interpreter_->inputs()); + } + pybind11::list get_output_zero_points() { + MINIMAL_CHECK(interpreter_); + return get_zero_points(interpreter_->outputs()); + } + // List of quantization scales, None for non-quantized tensors + pybind11::list get_input_scales() { + MINIMAL_CHECK(interpreter_); + return get_scales(interpreter_->inputs()); + } + pybind11::list get_output_scales() { + MINIMAL_CHECK(interpreter_); + return get_scales(interpreter_->outputs()); + } protected: + // Calls to MicroInterpreter::tensor allocate memory, so we must cache them + TfLiteTensor* get_tensor(size_t index) { + auto iter = tensors.find(index); + if (iter != tensors.end()) return iter->second; + TfLiteTensor* tensor = interpreter_->tensor(index); + tensors[index] = tensor; + return tensor; + } + std::unique_ptr interpreter_; + std::map tensors; + template + pybind11::list get_types(const TensorList& tensors); template - pybind11::list get_types(const TensorList& tensors) const; + pybind11::list get_shapes(const TensorList& tensors); template - pybind11::list get_shapes(const TensorList& tensors) const; + pybind11::list get_zero_points(const TensorList& tensors); + template + pybind11::list get_scales(const TensorList& tensors); }; TfLiteType TfLiteTypeFromPyType(pybind11::dtype py_type) { @@ -139,11 +171,11 @@ bool SetTensorFromNumpy(const TfLiteTensor* tensor, template template pybind11::list InterpreterWrapperBase::get_types( - const TensorList& tensors) const { + const TensorList& tensors) { pybind11::list result; for (auto tensor_id : tensors) { - const TfLiteTensor* tensor = interpreter_->tensor(tensor_id); + const TfLiteTensor* tensor = get_tensor(tensor_id); result.append(PyTypeFromTfLiteType(tensor->type)); } @@ -153,11 +185,11 @@ pybind11::list InterpreterWrapperBase::get_types( template template pybind11::list 
InterpreterWrapperBase::get_shapes( - const TensorList& tensors) const { + const TensorList& tensors) { pybind11::list result; for (auto tensor_id : tensors) { - const TfLiteTensor* tensor = interpreter_->tensor(tensor_id); + const TfLiteTensor* tensor = get_tensor(tensor_id); pybind11::tuple shape(tensor->dims->size); for (int j = 0; j < tensor->dims->size; ++j) shape[j] = tensor->dims->data[j]; @@ -167,6 +199,74 @@ pybind11::list InterpreterWrapperBase::get_shapes( return result; } +template +template +pybind11::list InterpreterWrapperBase::get_zero_points( + const TensorList& tensors) { + pybind11::list result; + + for (auto tensor_id : tensors) { + const TfLiteTensor* tensor = get_tensor(tensor_id); + + if (tensor->quantization.type == kTfLiteAffineQuantization) { + const int legacy_zero_point = tensor->params.zero_point; + + const auto* affine_quantization = + reinterpret_cast( + tensor->quantization.params); + MINIMAL_CHECK(affine_quantization); + MINIMAL_CHECK(affine_quantization->zero_point); + + // For per-channel quantization, the zero point should be the same for + // every channel + for (int i = 0; i < affine_quantization->zero_point->size; ++i) + MINIMAL_CHECK(affine_quantization->zero_point->data[i] == + legacy_zero_point); + + result.append(pybind11::cast(legacy_zero_point)); + } else { + result.append(pybind11::cast(Py_None)); + } + } + + return result; +} + +template +template +pybind11::list InterpreterWrapperBase::get_scales( + const TensorList& tensors) { + pybind11::list result; + + for (auto tensor_id : tensors) { + const TfLiteTensor* tensor = get_tensor(tensor_id); + + if (tensor->quantization.type == kTfLiteAffineQuantization) { + const float legacy_scale = tensor->params.scale; + + const auto* affine_quantization = + reinterpret_cast( + tensor->quantization.params); + MINIMAL_CHECK(affine_quantization); + MINIMAL_CHECK(affine_quantization->scale); + + if (affine_quantization->scale->size == 1) { + 
MINIMAL_CHECK(affine_quantization->scale->data[0] == legacy_scale); + result.append(pybind11::cast(legacy_scale)); + } else { + std::vector scales; + for (int i = 0; i < affine_quantization->scale->size; ++i) + scales.push_back(affine_quantization->scale->data[i]); + result.append(pybind11::cast(scales)); + } + } else { + result.append(pybind11::cast(Py_None)); + } + } + + return result; +} + template pybind11::list InterpreterWrapperBase::predict( const pybind11::list& input_list) { @@ -181,8 +281,7 @@ pybind11::list InterpreterWrapperBase::predict( for (size_t i = 0; i < inputs_size; ++i) { pybind11::array nparray = pybind11::array::ensure(input_list[i], pybind11::array::c_style); - const TfLiteTensor* tensor = - interpreter_->tensor(interpreter_->inputs()[i]); + const TfLiteTensor* tensor = get_tensor(interpreter_->inputs()[i]); if (!SetTensorFromNumpy(tensor, nparray)) { PY_ERROR("Failed to set tensor data of input " << i); } @@ -192,7 +291,7 @@ pybind11::list InterpreterWrapperBase::predict( pybind11::list result; for (auto output_id : interpreter_->outputs()) { - TfLiteTensor* tensor = interpreter_->tensor(output_id); + TfLiteTensor* tensor = get_tensor(output_id); std::vector shape(tensor->dims->data, tensor->dims->data + tensor->dims->size); pybind11::array nparray(PyTypeFromTfLiteType(tensor->type), shape,