Skip to content

Commit

Permalink
Optional ONNX import in onnxconversion
Browse files Browse the repository at this point in the history
  • Loading branch information
Felix-Mac committed Mar 28, 2023
1 parent a073cc8 commit c29a9ac
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions do_mpc/sysid/_onnxconversion.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import casadi
import onnx
import numpy as np
import pdb
import importlib


# Import optional packages

ONNX_INSTALLED = False

if importlib.util.find_spec("onnx"):
import onnx
ONNX_INSTALLED = True

class ONNXConversion:
""" Transform `ONNX model <https://onnx.ai>`_.
The transformation returns a CasADi expression of the model and can be used e.g. in the :py:class:`do_mpc.model.Model` class.
Expand Down Expand Up @@ -86,31 +94,17 @@ class ONNXConversion:
"""

def __init__(self, model, model_name=None, from_keras=False):
def __init__(self, model, model_name=None):
if not ONNX_INSTALLED:
raise Exception("The package 'onnx' is not installed. Please install it..")

# In case of a keras model as input, convert it to an ONNX model
if from_keras:
try:
import tf2onnx
import tensorflow as tf
except:
raise Exception("The package 'tf2onnx' or 'tensorflow' is not installed. Please install it..")

assert isinstance(model,(tf.keras.Model)), 'The input model is not a Keras model.'

model_input_signature = [tf.TensorSpec(np.array(self._determine_shape(inp_spec.shape)),
name=inp_spec.name) for inp_spec in model.input_spec]
self.onnx_model, _ = tf2onnx.convert.from_keras(model,output_path=None,
input_signature=model_input_signature)
self.name = "casadi_model" if not isinstance(model_name, (str)) else model.name

elif isinstance(model,(onnx.onnx_ml_pb2.ModelProto)):
if isinstance(model,(onnx.onnx_ml_pb2.ModelProto)):
self.onnx_model = model
self.name = "casadi_model" if not isinstance(model_name, (str)) else model_name

else:
raise Exception("Please pass a keras or onnx model as input. Please use the from_keras flag to convert a keras model to an onnx model.")



# From the ONNX model the graph and the nodes and the initializers are directly inherited
self.graph = self.onnx_model.graph
Expand Down

0 comments on commit c29a9ac

Please sign in to comment.