This is a notebook with different model conversion techniques

**Torch ↔ ONNX**

**ONNX → Tensorflow SavedModel**

**Tensorflow SavedModel → TFLite**

**JAX → TFLite**

**TFLite → ONNX**







## PyTorch function → PyTorch Module

In [2]:
import torch

class Lambda(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

sq_torch_module = Lambda(lambda x: x**2)

## PyTorch model → ONNX

In [3]:
!pip install -U onnx --quiet
!pip install -U onnx-tf --quiet
!pip install -U onnxruntime --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m32.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.1/226.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m611.8/611.8 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.4/6.4 MB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [6]:
import torch.onnx

torch.onnx.export(
    sq_torch_module,
    torch.tensor([1,2,3], dtype=torch.float32),
    f="temp.onnx",
    verbose=True,
    export_params=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': [0],
        'output': [0]
    }
)

Run model in ONNX runtime

In [5]:
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession("temp.onnx", providers=["CPUExecutionProvider"])

ort_inputs = {'input': np.array([3,4,5,6], dtype=np.float32)}
ort_outs = ort_session.run(None, ort_inputs)

ort_outs

[array([ 9., 16., 25., 36.], dtype=float32)]

## ONNX → Tensorflow Model

In [9]:
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf

onnx_model = onnx.load("./temp.onnx")
tf_sq_rep = prepare(onnx_model)

In [10]:
tf_sq_rep.run(tf.constant([2,4,6,8,3], dtype=tf.float32)) # running in ONNX runtime

Outputs(output=array([ 4., 16., 36., 64.,  9.], dtype=float32))

In [11]:
tf_sq_rep.export_graph('./temp.pb') # save to tensorflow model

In [12]:
tf_sq_saved_model = tf.saved_model.load('./temp.pb') # load tensorflow model

In [13]:
tf_sq_saved_model.signatures['serving_default'](input=tf.constant([1,2,3,4], dtype=tf.float32)) # inference

{'output': <tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 1.,  4.,  9., 16.], dtype=float32)>}

## Tensorflow → TFLite

In [15]:
converter = tf.lite.TFLiteConverter.from_saved_model(
    'temp_fixed.pb'
)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

tf_sq_lite_model = converter.convert()

In [18]:
input_tensor = np.array([5,6,7,8], dtype=np.float32)

interpreter = tf.lite.Interpreter(model_content=tf_sq_lite_model)
input = interpreter.get_input_details()[0]
interpreter.resize_tensor_input(input['index'], input_tensor.shape)
output = interpreter.get_output_details()[0]
interpreter.allocate_tensors()
interpreter.invoke()
interpreter.set_tensor(input['index'], input_tensor)
interpreter.invoke()
interpreter.get_tensor(output['index'])

array([25., 36., 49., 64.], dtype=float32)

## JAX → TFLite



In [22]:
import jax.numpy as jnp
import jax

In [45]:
@jax.jit
def jax_sq(x):
  return x**2

In [46]:
input_size = (3,)

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [jax_sq],
    [[('x', jnp.zeros(input_size))]]
)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

tf_jax_sq_lite_model = converter.convert()

In [47]:
input_tensor = np.array([7,6,5], dtype=np.float32)

interpreter = tf.lite.Interpreter(model_content=tf_jax_sq_lite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_tensor)
interpreter.invoke()
interpreter.get_tensor(output_details[0]["index"])

array([49., 36., 25.], dtype=float32)

In [54]:
with open('temp.tflite', 'wb') as f:
  f.write(tf_jax_sq_lite_model)

## TFLite → ONNX

In [51]:
!pip install tf2onnx --quiet

In [64]:
import tf2onnx

tf2onnx.convert.from_tflite(
    tflite_path='temp.tflite',
    output_path='temp.onnx'
)
None

In [65]:
ort_session = onnxruntime.InferenceSession("temp.onnx", providers=["CPUExecutionProvider"])

ort_inputs = {'x': np.array([6,5,4], dtype=np.float32)}
ort_outs = ort_session.run(None, ort_inputs)

ort_outs

[array([36., 25., 16.], dtype=float32)]

## ONNX → PyTorch

In [67]:
!pip install onnx2torch --quiet

In [70]:
import onnx2torch

torch_onnx_sq_module = onnx2torch.convert('temp.onnx')

In [73]:
torch_onnx_sq_module(torch.tensor([1,9,3,4]))

tensor([ 1, 81,  9, 16])