In [1]:
import cv2
import jax.numpy as jnp
import numpy as np
import onnx

from onnx_jax.backend import run_model


def _cosin_sim(a, b):
    a = a.flatten()
    b = b.flatten()
    cos_sim = jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
    return cos_sim


In [4]:
# Model: https://github.com/onnx/models/blob/master/vision/classification/mobilenet/model/mobilenetv2-7.onnx
# NOTE(review): absolute local paths below; adjust for your machine.
fp = '/home/allen/ml_data/models/onnx/mobilenetv2-7.onnx'
onnx_model = onnx.load_model(fp)
graph = onnx_model.graph

# NCHW input shape declared by the graph; for this model it is [1, 3, 224, 224].
input_shape = [x.dim_value for x in graph.input[0].type.tensor_type.shape.dim]
print(f"input shape: {input_shape}")

# Spatial size the model expects. dim_value is 0 for symbolic/dynamic dims,
# in which case fall back to the standard 224x224 ImageNet resolution.
in_h = input_shape[2] or 224
in_w = input_shape[3] or 224

# ImageNet per-channel normalization constants, in RGB channel order.
IMAGENET_MEAN = jnp.array([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1])
IMAGENET_STD = jnp.array([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1])

# Image: https://commons.wikimedia.org/wiki/File:Giant_Panda_in_Beijing_Zoo_1.JPG
# imagenet id of panda is 388
fp_img = '/home/allen/ml_data/data/Giant_Panda_in_Beijing_Zoo_1.jpeg'
img_bgr = cv2.imread(fp_img, cv2.IMREAD_COLOR)
if img_bgr is None:
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError(f"could not read image: {fp_img}")

# OpenCV decodes to BGR, but the mean/std above are for RGB input.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_rgb = cv2.resize(img_rgb, (in_w, in_h))  # cv2.resize takes (width, height)

# HWC uint8 -> NCHW float32 scaled to [0, 1], then per-channel standardization.
input_ = jnp.transpose(jnp.expand_dims(img_rgb, 0), [0, 3, 1, 2]).astype(jnp.float32)
input_ = (input_ / 255.0 - IMAGENET_MEAN) / IMAGENET_STD

# Run the model through onnx-jax; the argmax of the first output is the
# predicted class id.
outputs = run_model(onnx_model, [input_])
print(f"onnx-jax result: {jnp.argmax(outputs[0])}")

# Optional cross-check against onnxruntime. Only a missing package is treated
# as "skip"; any failure in the comparison itself is surfaced, not swallowed.
try:
    import onnxruntime as ort
except ImportError:
    ort = None
    print("onnxruntime not installed; skipping comparison")

if ort is not None:
    # An explicit provider list is required by onnxruntime when more than one
    # execution provider is available (e.g. GPU builds).
    sess = ort.InferenceSession(fp, providers=["CPUExecutionProvider"])
    # Ask the session for its input name rather than indexing graph.input,
    # which in some ONNX models also lists initializers.
    ort_input_name = sess.get_inputs()[0].name
    out = sess.run(None, {ort_input_name: np.asarray(input_)})
    print(f"onnxruntime result: {np.argmax(out[0])}")

    # Cosine similarity of the two output tensors; ~1.0 means they agree.
    sim = _cosin_sim(jnp.asarray(out[0]), outputs[0])
    print(f"Output tensor similarity with onnxruntime: {sim}")


input shape: [1, 3, 224, 224]
running: Conv, mobilenetv20_features_conv0_fwd
running: BatchNormalization, mobilenetv20_features_batchnorm0_fwd
running: Relu, mobilenetv20_features_relu0_fwd
running: Conv, mobilenetv20_features_linearbottleneck0_conv0_fwd
running: BatchNormalization, mobilenetv20_features_linearbottleneck0_batchnorm0_fwd
running: Relu, mobilenetv20_features_linearbottleneck0_relu0_fwd
running: Conv, mobilenetv20_features_linearbottleneck0_conv1_fwd
running: BatchNormalization, mobilenetv20_features_linearbottleneck0_batchnorm1_fwd
running: Relu, mobilenetv20_features_linearbottleneck0_relu1_fwd
running: Conv, mobilenetv20_features_linearbottleneck0_conv2_fwd
running: BatchNormalization, mobilenetv20_features_linearbottleneck0_batchnorm2_fwd
running: Conv, mobilenetv20_features_linearbottleneck1_conv0_fwd
running: BatchNormalization, mobilenetv20_features_linearbottleneck1_batchnorm0_fwd
running: Relu, mobilenetv20_features_linearbottleneck1_relu0_fwd
running: Conv, mobi