Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nano: fix keras onnx model output shape #7138

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -21,7 +21,7 @@
from tempfile import TemporaryDirectory
import tensorflow as tf
from bigdl.nano.utils.util import get_default_args
from bigdl.nano.tf.utils import KERAS_VERSION_LESS_2_10
from bigdl.nano.tf.utils import KERAS_VERSION_LESS_2_10, fake_tensor_from_spec
from bigdl.nano.utils.inference.tf.model import AcceleratedKerasModel
from bigdl.nano.utils.log4Error import invalidInputError

Expand Down Expand Up @@ -59,6 +59,14 @@ def __init__(self, model, input_spec=None, onnxruntime_session_options=None,
input_spec = (input_spec, )
tf2onnx.convert.from_keras(model, input_signature=input_spec,
output_path=onnx_path, **export_kwargs)
# save the nesting level of inference output
fake_inputs = [fake_tensor_from_spec(spec) for spec in input_spec]
fake_outputs = model(*fake_inputs)
self._nesting_level = 0
while isinstance(fake_outputs, (tuple, list)) and len(fake_outputs) == 1:
self._nesting_level += 1
fake_outputs = fake_outputs[0]

self._inputs_dtypes = [inp.dtype for inp in input_spec]
self._default_kwargs = get_default_args(model.call)
if KERAS_VERSION_LESS_2_10:
Expand All @@ -78,7 +86,10 @@ def __call__(self, *args, **kwargs):
kwargs[name] = value
for param in self._call_fn_args_backup[len(inputs):len(self._forward_args)]:
inputs.append(kwargs[param])
return self.call(*inputs)
outputs = self.call(*inputs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, this returns the original ONNX output (a list)? If so, maybe it already contains `_nesting_level` levels of nesting to begin with (rather than starting from 0)?
It's just a bit unclear whether this way of re-wrapping in lists works correctly in all cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, here `outputs` is a single Tensor, not the raw ONNX output.

Copy link
Contributor Author

@MeouSker77 MeouSker77 Jan 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.call will call tf.py_function(func, ..., Tout=tf.float32), which will convert the output of func to a single Tensor

for _i in range(self._nesting_level):
outputs = [outputs]
return outputs

def on_forward_start(self, inputs):
if self.ortsess is None:
Expand All @@ -89,8 +100,6 @@ def on_forward_start(self, inputs):
return inputs

def on_forward_end(self, outputs):
if isinstance(outputs, list) and len(outputs) == 1:
outputs = outputs[0]
outputs = self.numpy_to_tensors(outputs)
MeouSker77 marked this conversation as resolved.
Show resolved Hide resolved
return outputs

Expand Down Expand Up @@ -143,7 +152,8 @@ def _save_model(self, path):
super()._save_model(onnx_path)
attrs = {"_default_kwargs": self._default_kwargs,
"_call_fn_args_backup": self._call_fn_args_backup,
"_inputs_dtypes": self._inputs_dtypes}
"_inputs_dtypes": self._inputs_dtypes,
"_nesting_level": self._nesting_level}
with open(Path(path) / self.status['attr_path'], "wb") as f:
pickle.dump(attrs, f)
if self._is_compiled:
Expand Down
1 change: 1 addition & 0 deletions python/nano/src/bigdl/nano/tf/keras/inference/optimizer.py
Expand Up @@ -593,6 +593,7 @@ def quantize(model: Model,
outputs=outputs,
onnx_option='tensorflow',
onnxruntime_session_options=onnxruntime_session_options)
result._nesting_level = onnx_model._nesting_level
result._inputs_dtypes = onnx_model._inputs_dtypes
result._default_kwargs = onnx_model._default_kwargs
result._call_fn_args_backup = onnx_model._call_fn_args_backup
Expand Down
12 changes: 12 additions & 0 deletions python/nano/src/bigdl/nano/tf/utils.py
Expand Up @@ -15,6 +15,7 @@
#
import inspect
import operator
import tensorflow as tf
from tensorflow.keras import Model
from functools import partial
from bigdl.nano.common.compare_version import _compare_version
Expand Down Expand Up @@ -96,3 +97,14 @@ def patch_compiled(target_model: Model, source_model: Model):
kwargs["weighted_metrics"] = source_model.compiled_metrics._user_weighted_metrics
target_model.compile(**kwargs)
return target_model


def fake_tensor_from_spec(tensor_spec: tf.TensorSpec):
    """Construct a dummy `Tensor` that conforms to the given `TensorSpec`.

    Unknown (None) dimensions are replaced with a concrete size of 1.
    A scalar bool spec is treated specially: it is likely the `training`
    flag, so plain `False` is returned instead of a tensor.
    """
    dtype = tensor_spec.dtype
    # Replace every unknown dimension with 1 so a concrete tensor can be built.
    concrete_shape = tuple(1 if dim is None else dim for dim in tensor_spec.shape)
    if dtype == tf.bool and concrete_shape == ():
        # This may be the `training` parameter, we should assume it is False
        return False
    return tf.ones(shape=concrete_shape, dtype=dtype)
22 changes: 22 additions & 0 deletions python/nano/test/tf/keras/test_trace_and_quantize.py
Expand Up @@ -42,6 +42,15 @@ def do_nothing():
pass


class MyModelReturnList(tf.keras.Model):
    """Toy model whose `call` wraps its single output tensor in a list."""

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)

    def call(self, inputs: tf.Tensor):
        # Deliberately return a one-element list so tests can check that
        # accelerated models preserve the output nesting level.
        return [self.dense1(inputs)]


class TestTraceAndQuantize(TestCase):
def test_attribute_access_after_trace(self):
x = 100
Expand Down Expand Up @@ -99,3 +108,16 @@ def test_evaluate_after_trace(self):
InferenceOptimizer.save(traced_model, tmp_dir_name)
new_model = InferenceOptimizer.load(tmp_dir_name, model)
new_model.evaluate(x=x, y=y)

def test_inference_output_shape(self):
model = MyModelReturnList()
x = np.random.random((100, 4))
traced_model = InferenceOptimizer.trace(model, accelerator="onnxruntime",
input_spec=tf.TensorSpec(shape=(None, 4), dtype=tf.float32))
outputs = traced_model(x)
assert isinstance(outputs, list) and isinstance(outputs[0], tf.Tensor)

quantized_model = InferenceOptimizer.quantize(model, accelerator="onnxruntime",
input_spec=tf.TensorSpec(shape=(None, 4)), x=x)
outputs = quantized_model(x)
assert isinstance(outputs, list) and isinstance(outputs[0], tf.Tensor)