Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion larq_compute_engine/tflite/python/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterator, List, Tuple, Union
from typing import Iterator, List, Tuple, Union, Optional

import numpy as np
from tqdm import tqdm
Expand Down Expand Up @@ -78,6 +78,16 @@ def input_shapes(self) -> List[Tuple[int]]:
"""Returns a list of input shapes."""
return self.interpreter.input_shapes

@property
def input_scales(self) -> List[Optional[Union[float, List[float]]]]:
    """Quantization scale of each model input.

    Each entry is a float (per-tensor quantization), a list of floats
    (per-channel quantization), or None for a non-quantized input.
    """
    wrapped = self.interpreter
    return wrapped.input_scales

@property
def input_zero_points(self) -> List[Optional[int]]:
    """Quantization zero point of each model input.

    Each entry is an int, or None for a non-quantized input.
    """
    wrapped = self.interpreter
    return wrapped.input_zero_points

@property
def output_types(self) -> list:
"""Returns a list of output types."""
Expand All @@ -88,6 +98,16 @@ def output_shapes(self) -> List[Tuple[int]]:
"""Returns a list of output shapes."""
return self.interpreter.output_shapes

@property
def output_scales(self) -> List[Optional[Union[float, List[float]]]]:
    """Returns a list of output scales.

    Each entry is a float (per-tensor quantization), a list of floats
    (per-channel quantization), or None for a non-quantized output.
    """
    # Fixed copy-paste error: docstring previously said "input scales".
    return self.interpreter.output_scales

@property
def output_zero_points(self) -> List[Optional[int]]:
    """Returns a list of output zero points.

    Each entry is an int, or None for a non-quantized output.
    """
    # Fixed copy-paste error: docstring previously said "input zero points".
    return self.interpreter.output_zero_points

def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data:
"""Generates output predictions for the input samples.

Expand Down
8 changes: 8 additions & 0 deletions larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,13 @@ PYBIND11_MODULE(interpreter_wrapper_lite, m) {
nullptr)
.def_property("output_shapes", &LiteInterpreterWrapper::get_output_shapes,
nullptr)
.def_property("input_zero_points",
&LiteInterpreterWrapper::get_input_zero_points, nullptr)
.def_property("output_zero_points",
&LiteInterpreterWrapper::get_output_zero_points, nullptr)
.def_property("input_scales", &LiteInterpreterWrapper::get_input_scales,
nullptr)
.def_property("output_scales", &LiteInterpreterWrapper::get_output_scales,
nullptr)
.def("predict", &LiteInterpreterWrapper::predict);
};
125 changes: 112 additions & 13 deletions larq_compute_engine/tflite/python/interpreter_wrapper_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,62 @@ class InterpreterWrapperBase {
pybind11::list predict(const pybind11::list& input_list);

// List of numpy types
pybind11::list get_input_types() const {
pybind11::list get_input_types() {
MINIMAL_CHECK(interpreter_);
return get_types(interpreter_->inputs());
}
pybind11::list get_output_types() const {
pybind11::list get_output_types() {
MINIMAL_CHECK(interpreter_);
return get_types(interpreter_->outputs());
}
// List of shape tuples
pybind11::list get_input_shapes() const {
pybind11::list get_input_shapes() {
MINIMAL_CHECK(interpreter_);
return get_shapes(interpreter_->inputs());
}
pybind11::list get_output_shapes() const {
pybind11::list get_output_shapes() {
MINIMAL_CHECK(interpreter_);
return get_shapes(interpreter_->outputs());
}
// List of zero points, None for non-quantized tensors
pybind11::list get_input_zero_points() {
MINIMAL_CHECK(interpreter_);
return get_zero_points(interpreter_->inputs());
}
pybind11::list get_output_zero_points() {
MINIMAL_CHECK(interpreter_);
return get_zero_points(interpreter_->outputs());
}
// List of quantization scales, None for non-quantized tensors
pybind11::list get_input_scales() {
MINIMAL_CHECK(interpreter_);
return get_scales(interpreter_->inputs());
}
pybind11::list get_output_scales() {
MINIMAL_CHECK(interpreter_);
return get_scales(interpreter_->outputs());
}

protected:
// Calls to MicroInterpreter::tensor allocate memory, so we must cache them
TfLiteTensor* get_tensor(size_t index) {
auto iter = tensors.find(index);
if (iter != tensors.end()) return iter->second;
TfLiteTensor* tensor = interpreter_->tensor(index);
tensors[index] = tensor;
return tensor;
}

std::unique_ptr<InterpreterType> interpreter_;
std::map<int, TfLiteTensor*> tensors;
template <typename TensorList>
pybind11::list get_types(const TensorList& tensors);
template <typename TensorList>
pybind11::list get_types(const TensorList& tensors) const;
pybind11::list get_shapes(const TensorList& tensors);
template <typename TensorList>
pybind11::list get_shapes(const TensorList& tensors) const;
pybind11::list get_zero_points(const TensorList& tensors);
template <typename TensorList>
pybind11::list get_scales(const TensorList& tensors);
};

TfLiteType TfLiteTypeFromPyType(pybind11::dtype py_type) {
Expand Down Expand Up @@ -139,11 +171,11 @@ bool SetTensorFromNumpy(const TfLiteTensor* tensor,
template <typename InterpreterType>
template <typename TensorList>
pybind11::list InterpreterWrapperBase<InterpreterType>::get_types(
const TensorList& tensors) const {
const TensorList& tensors) {
pybind11::list result;

for (auto tensor_id : tensors) {
const TfLiteTensor* tensor = interpreter_->tensor(tensor_id);
const TfLiteTensor* tensor = get_tensor(tensor_id);
result.append(PyTypeFromTfLiteType(tensor->type));
}

Expand All @@ -153,11 +185,11 @@ pybind11::list InterpreterWrapperBase<InterpreterType>::get_types(
template <typename InterpreterType>
template <typename TensorList>
pybind11::list InterpreterWrapperBase<InterpreterType>::get_shapes(
const TensorList& tensors) const {
const TensorList& tensors) {
pybind11::list result;

for (auto tensor_id : tensors) {
const TfLiteTensor* tensor = interpreter_->tensor(tensor_id);
const TfLiteTensor* tensor = get_tensor(tensor_id);
pybind11::tuple shape(tensor->dims->size);
for (int j = 0; j < tensor->dims->size; ++j)
shape[j] = tensor->dims->data[j];
Expand All @@ -167,6 +199,74 @@ pybind11::list InterpreterWrapperBase<InterpreterType>::get_shapes(
return result;
}

// Collects the quantization zero point of each tensor in `tensors`.
//
// Returns a Python list with one entry per tensor id: the integer
// per-tensor zero point for affine-quantized tensors, or None for
// non-quantized tensors. The Python API exposes a single zero point per
// tensor, so for per-channel quantization we verify that every channel's
// zero point equals the legacy per-tensor value.
template <typename InterpreterType>
template <typename TensorList>
pybind11::list InterpreterWrapperBase<InterpreterType>::get_zero_points(
    const TensorList& tensors) {
  pybind11::list result;

  for (auto tensor_id : tensors) {
    const TfLiteTensor* tensor = get_tensor(tensor_id);

    if (tensor->quantization.type == kTfLiteAffineQuantization) {
      const int legacy_zero_point = tensor->params.zero_point;

      // `quantization.params` is an untyped pointer; `static_cast` (not
      // `reinterpret_cast`) is the correct way to recover the typed struct
      // from a `void*`. We only read from it, so keep the pointee const.
      const auto* affine_quantization =
          static_cast<const TfLiteAffineQuantization*>(
              tensor->quantization.params);
      MINIMAL_CHECK(affine_quantization);
      MINIMAL_CHECK(affine_quantization->zero_point);

      // For per-channel quantization, the zero point should be the same for
      // every channel
      for (int i = 0; i < affine_quantization->zero_point->size; ++i)
        MINIMAL_CHECK(affine_quantization->zero_point->data[i] ==
                      legacy_zero_point);

      result.append(pybind11::cast(legacy_zero_point));
    } else {
      // Non-quantized tensor: report None on the Python side.
      result.append(pybind11::cast<pybind11::none>(Py_None));
    }
  }

  return result;
}

// Collects the quantization scale(s) of each tensor in `tensors`.
//
// Returns a Python list with one entry per tensor id: a single float for
// per-tensor quantization, a list of floats for per-channel quantization,
// or None for non-quantized tensors. For the single-scale case we verify
// the scale agrees with the legacy `params.scale` field (exact float
// comparison is intentional: both values originate from the same model
// data, so they must be bit-identical).
template <typename InterpreterType>
template <typename TensorList>
pybind11::list InterpreterWrapperBase<InterpreterType>::get_scales(
    const TensorList& tensors) {
  pybind11::list result;

  for (auto tensor_id : tensors) {
    const TfLiteTensor* tensor = get_tensor(tensor_id);

    if (tensor->quantization.type == kTfLiteAffineQuantization) {
      const float legacy_scale = tensor->params.scale;

      // `quantization.params` is an untyped pointer; `static_cast` (not
      // `reinterpret_cast`) is the correct way to recover the typed struct
      // from a `void*`. We only read from it, so keep the pointee const.
      const auto* affine_quantization =
          static_cast<const TfLiteAffineQuantization*>(
              tensor->quantization.params);
      MINIMAL_CHECK(affine_quantization);
      MINIMAL_CHECK(affine_quantization->scale);

      if (affine_quantization->scale->size == 1) {
        // Per-tensor quantization: expose a plain float.
        MINIMAL_CHECK(affine_quantization->scale->data[0] == legacy_scale);
        result.append(pybind11::cast(legacy_scale));
      } else {
        // Per-channel quantization: expose all channel scales as a list.
        std::vector<float> scales;
        scales.reserve(affine_quantization->scale->size);
        for (int i = 0; i < affine_quantization->scale->size; ++i)
          scales.push_back(affine_quantization->scale->data[i]);
        result.append(pybind11::cast(scales));
      }
    } else {
      // Non-quantized tensor: report None on the Python side.
      result.append(pybind11::cast<pybind11::none>(Py_None));
    }
  }

  return result;
}

template <typename InterpreterType>
pybind11::list InterpreterWrapperBase<InterpreterType>::predict(
const pybind11::list& input_list) {
Expand All @@ -181,8 +281,7 @@ pybind11::list InterpreterWrapperBase<InterpreterType>::predict(
for (size_t i = 0; i < inputs_size; ++i) {
pybind11::array nparray =
pybind11::array::ensure(input_list[i], pybind11::array::c_style);
const TfLiteTensor* tensor =
interpreter_->tensor(interpreter_->inputs()[i]);
const TfLiteTensor* tensor = get_tensor(interpreter_->inputs()[i]);
if (!SetTensorFromNumpy(tensor, nparray)) {
PY_ERROR("Failed to set tensor data of input " << i);
}
Expand All @@ -192,7 +291,7 @@ pybind11::list InterpreterWrapperBase<InterpreterType>::predict(

pybind11::list result;
for (auto output_id : interpreter_->outputs()) {
TfLiteTensor* tensor = interpreter_->tensor(output_id);
TfLiteTensor* tensor = get_tensor(output_id);
std::vector<int> shape(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
pybind11::array nparray(PyTypeFromTfLiteType(tensor->type), shape,
Expand Down