In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from os import cpu_count
from collections import OrderedDict

import numpy as np

import onnx
from onnx import helper
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

In [3]:
MODEL_PATH_ONNX = 'models/universal-sentence-encoder-multilingual-large-3.onnx'
MODEL_PATH_TF = 'models/universal-sentence-encoder-multilingual-large-3'

In [4]:
def load_onnx_model(model_filepath):
    _options = ort.SessionOptions()
    _options.inter_op_num_threads, _options.intra_op_num_threads = cpu_count(), cpu_count()
    _options.register_custom_ops_library(get_library_path())
    _providers = ["CPUExecutionProvider"]  # could use ort.get_available_providers()
    return ort.InferenceSession(path_or_bytes=model_filepath, sess_options=_options, providers=_providers)

In [5]:
model_onnx_runtime = load_onnx_model(MODEL_PATH_ONNX)
model_onnx_runtime

<onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x357065790>

In [6]:
sentences = ["Hello, how are you?"]
output_onnx = model_onnx_runtime.run(output_names=["outputs"], input_feed={"inputs": sentences})

output_onnx[0][0]

array([ 2.77999556e-03,  2.53025256e-02, -1.34592364e-02, -6.00954220e-02,
       -3.57863074e-03,  2.16413494e-02,  1.99971776e-02, -3.70096900e-02,
       -2.91313455e-02, -1.07441895e-01,  2.73036715e-02, -8.28249753e-02,
        1.23869767e-02,  2.55314987e-02, -3.11783608e-02,  7.67881498e-02,
        5.88866062e-02,  9.00683925e-02, -2.02288944e-02,  5.06051704e-02,
        4.59435508e-02, -9.70152542e-02,  1.44991148e-02,  1.12705873e-02,
       -2.17566243e-03, -3.42216855e-03,  7.06949756e-02, -3.22969854e-02,
       -2.33398993e-02, -2.30249614e-02,  1.44220721e-02, -1.73140299e-02,
       -6.50571808e-02, -6.18350208e-02,  7.21207634e-02, -6.45894557e-02,
        2.22161356e-02,  8.38197710e-04, -1.77222826e-02, -4.84013781e-02,
        4.19012941e-02, -1.89781915e-02, -1.09328300e-01,  4.43844795e-02,
        1.54845342e-02,  4.25968096e-02,  4.46535386e-02,  7.47142211e-02,
        2.83761304e-02,  2.99820192e-02, -1.81293953e-02,  1.09258546e-02,
       -2.19645035e-02,  

### Validate

In [7]:
model_tf = tf.saved_model.load(MODEL_PATH_TF)

In [8]:
output_tf = model_tf(sentences)

output_tf.numpy()[0]

array([ 2.77998135e-03,  2.53025256e-02, -1.34592326e-02, -6.00954033e-02,
       -3.57862748e-03,  2.16413625e-02,  1.99971721e-02, -3.70096639e-02,
       -2.91313529e-02, -1.07441902e-01,  2.73036622e-02, -8.28249753e-02,
        1.23869926e-02,  2.55314838e-02, -3.11783515e-02,  7.67881498e-02,
        5.88866174e-02,  9.00684148e-02, -2.02289019e-02,  5.06051891e-02,
        4.59435731e-02, -9.70152542e-02,  1.44991055e-02,  1.12706097e-02,
       -2.17566569e-03, -3.42217949e-03,  7.06949830e-02, -3.22969817e-02,
       -2.33398993e-02, -2.30249688e-02,  1.44220600e-02, -1.73140336e-02,
       -6.50571808e-02, -6.18350431e-02,  7.21207708e-02, -6.45894706e-02,
        2.22161412e-02,  8.38220119e-04, -1.77222937e-02, -4.84013595e-02,
        4.19013053e-02, -1.89781878e-02, -1.09328344e-01,  4.43844795e-02,
        1.54845491e-02,  4.25968394e-02,  4.46535647e-02,  7.47142434e-02,
        2.83761024e-02,  2.99820378e-02, -1.81293916e-02,  1.09258704e-02,
       -2.19645463e-02,  

In [9]:
np.allclose(output_onnx, output_tf, atol=1e-7)

True

### Debug

In [10]:
model_onnx = onnx.load(MODEL_PATH_ONNX)

In [11]:
dir(model_onnx)

['ByteSize',
 'Clear',
 'ClearExtension',
 'ClearField',
 'CopyFrom',
 'DESCRIPTOR',
 'DOC_STRING_FIELD_NUMBER',
 'DOMAIN_FIELD_NUMBER',
 'DiscardUnknownFields',
 'FUNCTIONS_FIELD_NUMBER',
 'FindInitializationErrors',
 'FromString',
 'GRAPH_FIELD_NUMBER',
 'HasExtension',
 'HasField',
 'IR_VERSION_FIELD_NUMBER',
 'IsInitialized',
 'ListFields',
 'METADATA_PROPS_FIELD_NUMBER',
 'MODEL_VERSION_FIELD_NUMBER',
 'MergeFrom',
 'MergeFromString',
 'OPSET_IMPORT_FIELD_NUMBER',
 'PRODUCER_NAME_FIELD_NUMBER',
 'PRODUCER_VERSION_FIELD_NUMBER',
 'ParseFromString',
 'RegisterExtension',
 'SerializePartialToString',
 'SerializeToString',
 'SetInParent',
 'TRAINING_INFO_FIELD_NUMBER',
 'UnknownFields',
 'WhichOneof',
 '_InternalParse',
 '_InternalSerialize',
 '_Modified',
 '_SetListener',
 '_UpdateOneofState',
 '__class__',
 '__deepcopy__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_s

In [12]:
type(model_onnx.graph)

onnx.onnx_ml_pb2.GraphProto

In [13]:
model_onnx.graph.node

[input: "inputs"
input: "const_axes__298"
input: "const_fold_opt__1488"
input: "const_fold_opt__1460"
input: "const_fold_opt__1460"
input: "const_fold_opt__1497"
output: "StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:0"
output: "StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:1"
name: "StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp"
op_type: "SentencepieceTokenizer"
attribute {
  name: "Tsplits"
  i: 7
  type: INT
}
attribute {
  name: "out_type"
  i: 6
  type: INT
}
attribute {
  name: "return_nbest"
  i: 0
  type: INT
}
attribute {
  name: "model"
  type: STRING
}
domain: "ai.onnx.contrib"
, input: "StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:0"
output: "StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape:0"
name: "StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUniqu

In [14]:
# add all intermediate outputs to onnx net
ort_session = model_onnx_runtime
org_outputs = [x.name for x in ort_session.get_outputs()]

model = model_onnx
node_outputs = []
for node in model.graph.node:
    for output in node.output:
        if output not in org_outputs:
            node_outputs.append(output)

In [15]:
node_outputs

['StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:0',
 'StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:1',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape:0',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape__199:0',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Reshape:0',
 'Slice__25:0',
 'Slice__20:0',
 'Sub__27:0',
 'Shape__33:0',
 'ReduceMax__31:0',
 'StatefulPartitionedCall/text_preprocessor_1/RaggedToSparse/RaggedTensorToSparse_Concat__64:0',
 'StatefulPartitionedCall/EncoderTransformer/Transformer/SparseTransformerEncode/Layer_5/SelfAttention/SparseMultiheadAttention/ComputeQKV/concat:0',
 'StatefulPartitionedCall/EncoderTransformer/Transformer/AttentionPooling/Shape/Cast:0',
 'StatefulPartitionedCall/EncoderTransformer/Transformer/AttentionPooling/strided_slice:0',
 'StatefulPartitione

In [16]:
node_outputs = [
    node for node in node_outputs if \
        ('text_preprocessor_1' in node) or ('EmbeddingLookup' in node)
]

node_outputs += [
    'StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:0',
    'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/embedding_lookup/floordiv:0',
    'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Unique__208_cast:0',
    'Concat__524:0',
    'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Reshape_1:0',
    'Cast__220:0',
]

node_outputs

['StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:0',
 'StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:1',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape:0',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape__199:0',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Reshape:0',
 'StatefulPartitionedCall/text_preprocessor_1/RaggedToSparse/RaggedTensorToSparse_Concat__64:0',
 'StatefulPartitionedCall/text_preprocessor_1/RaggedToSparse/RaggedTensorToSparse_Concat__61:0',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Unique__210_cast:0',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Unique__208_cast:0',
 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/embedding_lookup/floordiv:0',
 'StatefulParti

In [17]:
model.graph.output.extend([onnx.ValueInfoProto(name=x) for x in node_outputs])

In [18]:
model.graph.output

[name: "outputs"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "unk__1551"
      }
      dim {
        dim_value: 512
      }
    }
  }
}
, name: "StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:0"
, name: "StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:1"
, name: "StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape:0"
, name: "StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape__199:0"
, name: "StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Reshape:0"
, name: "StatefulPartitionedCall/text_preprocessor_1/RaggedToSparse/RaggedTensorToSparse_Concat__64:0"
, name: "StatefulPartitionedCall/text_preprocessor_1/RaggedToSparse/RaggedTensorToSparse_Concat__61:0"
, name: "StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Unique__210_cast:0"
, name: "

In [19]:
# excute onnx
ort_session = load_onnx_model(model.SerializeToString())
outputs = [x.name for x in ort_session.get_outputs()]

ort_outs = ort_session.run(output_names=outputs, input_feed={"inputs": sentences})
ort_outs = OrderedDict(zip(outputs, ort_outs))

In [20]:
ort_outs

OrderedDict([('outputs',
              array([[ 2.77999556e-03,  2.53025256e-02, -1.34592364e-02,
                      -6.00954220e-02, -3.57863074e-03,  2.16413494e-02,
                       1.99971776e-02, -3.70096900e-02, -2.91313455e-02,
                      -1.07441895e-01,  2.73036715e-02, -8.28249753e-02,
                       1.23869767e-02,  2.55314987e-02, -3.11783608e-02,
                       7.67881498e-02,  5.88866062e-02,  9.00683925e-02,
                      -2.02288944e-02,  5.06051704e-02,  4.59435508e-02,
                      -9.70152542e-02,  1.44991148e-02,  1.12705873e-02,
                      -2.17566243e-03, -3.42216855e-03,  7.06949756e-02,
                      -3.22969854e-02, -2.33398993e-02, -2.30249614e-02,
                       1.44220721e-02, -1.73140299e-02, -6.50571808e-02,
                      -6.18350208e-02,  7.21207634e-02, -6.45894557e-02,
                       2.22161356e-02,  8.38197710e-04, -1.77222826e-02,
                      -4.8

In [21]:
ort_outs.keys()

odict_keys(['outputs', 'StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:0', 'StatefulPartitionedCall/text_preprocessor_1/SentenceTokenizer/SentencepieceTokenizeOp:1', 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape:0', 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Shape__199:0', 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Reshape:0', 'StatefulPartitionedCall/text_preprocessor_1/RaggedToSparse/RaggedTensorToSparse_Concat__64:0', 'StatefulPartitionedCall/text_preprocessor_1/RaggedToSparse/RaggedTensorToSparse_Concat__61:0', 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Unique__210_cast:0', 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/Unique__208_cast:0', 'StatefulPartitionedCall/EncoderTransformer/EmbeddingLookup/EmbeddingLookupUnique/embedding_lookup/floordiv:0', 'S

In [22]:
ort_outs['outputs']

array([[ 2.77999556e-03,  2.53025256e-02, -1.34592364e-02,
        -6.00954220e-02, -3.57863074e-03,  2.16413494e-02,
         1.99971776e-02, -3.70096900e-02, -2.91313455e-02,
        -1.07441895e-01,  2.73036715e-02, -8.28249753e-02,
         1.23869767e-02,  2.55314987e-02, -3.11783608e-02,
         7.67881498e-02,  5.88866062e-02,  9.00683925e-02,
        -2.02288944e-02,  5.06051704e-02,  4.59435508e-02,
        -9.70152542e-02,  1.44991148e-02,  1.12705873e-02,
        -2.17566243e-03, -3.42216855e-03,  7.06949756e-02,
        -3.22969854e-02, -2.33398993e-02, -2.30249614e-02,
         1.44220721e-02, -1.73140299e-02, -6.50571808e-02,
        -6.18350208e-02,  7.21207634e-02, -6.45894557e-02,
         2.22161356e-02,  8.38197710e-04, -1.77222826e-02,
        -4.84013781e-02,  4.19012941e-02, -1.89781915e-02,
        -1.09328300e-01,  4.43844795e-02,  1.54845342e-02,
         4.25968096e-02,  4.46535386e-02,  7.47142211e-02,
         2.83761304e-02,  2.99820192e-02, -1.81293953e-0