This notebook shows how to convert the TensorFlow (TF) variant of a SegFormer model to ONNX for faster inference on CPUs.

## Installs

In [1]:
!pip install -Uqq tf2onnx
!pip install -Uqq onnxruntime
!pip install git+https://github.com/huggingface/transformers -q



## Imports

In [2]:
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation

from PIL import Image
import numpy as np
import time

import onnx
import tf2onnx
import tensorflow as tf
import onnxruntime as ort

## Load model

You can learn more about the SegFormer model [here](https://huggingface.co/docs/transformers/main/en/model_doc/segformer). You can find all the pre-trained TensorFlow checkpoints [here](https://huggingface.co/models?library=tf&other=segformer&sort=downloads).

In [3]:
model_ckpt = "nvidia/segformer-b5-finetuned-ade-640-640"

feature_extractor = SegformerFeatureExtractor.from_pretrained(model_ckpt)
model = TFSegformerForSemanticSegmentation.from_pretrained(model_ckpt)


All model checkpoint layers were used when initializing TFSegformerForSemanticSegmentation.

All the layers of TFSegformerForSemanticSegmentation were initialized from the model checkpoint at nvidia/segformer-b5-finetuned-ade-640-640.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFSegformerForSemanticSegmentation for predictions without further training.
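The feature extractor loaded above resizes and normalizes images into the channels-first `pixel_values` layout that the model (and later the ONNX graph) expects. To illustrate only the normalization and layout step, here is a NumPy sketch. It assumes the image has already been resized to the target resolution and uses the ImageNet mean and standard deviation as normalization constants. Check `feature_extractor.image_mean` and `feature_extractor.image_std` for the values your checkpoint actually uses. The helper name `to_pixel_values` is hypothetical.

```python
import numpy as np

# Assumed normalization constants (ImageNet statistics); verify against
# feature_extractor.image_mean / feature_extractor.image_std.
IMAGE_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGE_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_pixel_values(image_hwc: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn an already-resized HxWx3 uint8 RGB image
    into a (1, 3, H, W) float32 batch."""
    if image_hwc.ndim != 3 or image_hwc.shape[-1] != 3:
        raise ValueError("expected an HxWx3 RGB image")
    # Rescale [0, 255] -> [0, 1], then normalize each channel.
    x = image_hwc.astype(np.float32) / 255.0
    x = (x - IMAGE_MEAN) / IMAGE_STD
    # HWC -> CHW, then add a leading batch dimension.
    return np.transpose(x, (2, 0, 1))[np.newaxis, ...]


# Usage with a synthetic 640x640 image
image = np.random.randint(0, 256, size=(640, 640, 3), dtype=np.uint8)
pixel_values = to_pixel_values(image)
print(pixel_values.shape, pixel_values.dtype)  # (1, 3, 640, 640) float32
```

The channels-first output matches the `[None, 3, input_size, input_size]` input signature used for the ONNX export.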


## ONNX conversion

In [4]:
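size = feature_extractor.size
# Newer versions of transformers store `size` as a dict ({"height": ..., "width": ...});
# older versions store a single int.
input_size = size["height"] if isinstance(size, dict) else size

# Channels-first input with a dynamic batch dimension, matching `pixel_values`
input_signature = [
    tf.TensorSpec([None, 3, input_size, input_size], tf.float32, name="pixel_values")
]
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=15)
onnx_model_path = model_ckpt.split("/")[-1] + ".onnx"
onnx.save(onnx_model, onnx_model_path)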



In [5]:
!ls -lh {onnx_model_path}

-rw-r--r-- 1 root root 326M Jul 24 15:51 segformer-b5-finetuned-ade-640-640.onnx


## Compare ONNX predictions

In [6]:
dummy_inputs = tf.random.normal((1, 3, input_size, input_size))
dummy_inputs_numpy = dummy_inputs.numpy()

In [7]:
tf_outputs = model(dummy_inputs, training=False)

In [8]:
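# Pin the session to the CPU execution provider, since this notebook targets CPU inference
sess = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
ort_outputs = sess.run(None, {"pixel_values": dummy_inputs_numpy})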

In [9]:
list(tf_outputs.logits.shape) == list(ort_outputs[0].shape)

True
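The logits compared here are per-class scores. According to the Transformers SegFormer documentation, they come out at a quarter of the input resolution. To turn them into a segmentation map, take the argmax over the class axis and upsample. The sketch below is a simplification: it upsamples the label map with nearest-neighbour repetition. Bilinear upsampling of the logits before the argmax is the smoother alternative. The helper `logits_to_mask` and the 4x scale default are assumptions for illustration.

```python
import numpy as np


def logits_to_mask(logits: np.ndarray, scale: int = 4) -> np.ndarray:
    """Hypothetical helper: (batch, num_labels, h, w) logits ->
    (batch, h * scale, w * scale) integer class mask.

    Takes the per-pixel argmax over the class axis, then upsamples the label
    map by repeating each pixel `scale` times along both spatial axes.
    """
    labels = np.argmax(logits, axis=1)          # (batch, h, w)
    labels = np.repeat(labels, scale, axis=1)   # upsample rows
    return np.repeat(labels, scale, axis=2)     # upsample columns


# Usage with synthetic logits: 150 ADE20K classes at 1/4 of a 640x640 input
fake_logits = np.random.randn(1, 150, 160, 160).astype(np.float32)
mask = logits_to_mask(fake_logits)
print(mask.shape)  # (1, 640, 640)
```

With the notebook's variables, the ONNX result would be `logits_to_mask(ort_outputs[0])`. Note that `sess.run` returns a list of outputs, so index it first.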

In [11]:
# `sess.run` returns a list of outputs; index the first one (the logits)
np.allclose(tf_outputs.logits.numpy(), ort_outputs[0], rtol=1e-5, atol=1e-5)

True
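`np.allclose` returns a single boolean. When the check fails, it also helps to know how far apart the outputs are. The following is a minimal sketch of a comparison helper that reports the maximum absolute difference and rejects shape mismatches instead of silently broadcasting them. The name `compare_outputs` is hypothetical, and the default tolerances simply mirror the ones used above.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-5, atol=1e-5):
    """Hypothetical helper: report how closely two output arrays agree.

    Returns (max_abs_diff, all_close). The arrays must have identical shapes,
    so a mismatch raises an error instead of being silently broadcast.
    """
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    all_close = bool(np.allclose(reference, candidate, rtol=rtol, atol=atol))
    return max_abs_diff, all_close


# Usage with synthetic data standing in for the TF and ONNX logits
ref = np.ones((1, 150, 160, 160), dtype=np.float32)
print(compare_outputs(ref, ref + 1e-7))
```

In the notebook, this would be `compare_outputs(tf_outputs.logits.numpy(), ort_outputs[0])`. Passing the whole `ort_outputs` list instead would trip the shape check, because the list adds an extra leading dimension.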

## Benchmarking speed

### TF model

In [12]:
# Warmup
print("Benchmarking TF model...")
for _ in range(2):
    _ = model(dummy_inputs, training=False)

# Timing
tf_outputs = []
start_time = time.time()
for _ in range(25):
    tf_outputs.append(model(dummy_inputs, training=False))
end_time = time.time()
print(f"Inference completed within {(end_time - start_time):.2f} seconds.")

Benchmarking TF model...
Inference completed within 628.32 seconds.
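The loop above reports one wall-clock total for all 25 runs. A variant that times each call separately with `time.perf_counter` (a monotonic, high-resolution clock) also yields the mean and spread per inference. Here is a minimal sketch. The helper name `benchmark` and its parameters are hypothetical, and the usage line substitutes a cheap stand-in workload for the model.

```python
import statistics
import time


def benchmark(fn, warmup=2, iters=25):
    """Hypothetical helper: time `fn()` and return per-call latencies (s).

    Runs `warmup` untimed calls first, then times each of `iters` calls
    individually with time.perf_counter.
    """
    for _ in range(warmup):
        fn()
    latencies = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        latencies.append(time.perf_counter() - start)
    return latencies


# Usage with a cheap stand-in workload
lat = benchmark(lambda: sum(range(10_000)), warmup=2, iters=25)
print(f"mean {statistics.mean(lat) * 1e3:.3f} ms, "
      f"stdev {statistics.stdev(lat) * 1e3:.3f} ms over {len(lat)} runs")
```

With the notebook's objects, the two workloads would be `lambda: model(dummy_inputs, training=False)` and `lambda: sess.run(None, {"pixel_values": dummy_inputs_numpy})`. Unlike the loops in this section, the helper discards outputs instead of keeping all 25 in memory.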


### ONNX model

In [13]:
# Warmup
for _ in range(2):
    _ = sess.run(None, {"pixel_values": dummy_inputs_numpy})

# Timing
ort_outputs = []
start_time = time.time()
for _ in range(25):
    ort_outputs.append(sess.run(None, {"pixel_values": dummy_inputs_numpy}))
end_time = time.time()
print(f"Inference completed within {(end_time - start_time):.2f} seconds.")

Inference completed within 250.47 seconds.
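Both totals cover 25 timed runs, so dividing by 25 gives the latency per inference, and their ratio gives the speedup. The arithmetic below uses only the two numbers reported above. They come from a single run on one machine, so expect different figures on other CPUs.

```python
# Totals reported by the two benchmark cells above (25 timed runs each)
runs = 25
tf_total_s = 628.32
onnx_total_s = 250.47

tf_per_run = tf_total_s / runs       # ~25.13 s per inference
onnx_per_run = onnx_total_s / runs   # ~10.02 s per inference
speedup = tf_total_s / onnx_total_s  # ~2.51x

print(f"TF: {tf_per_run:.2f} s/run, ONNX: {onnx_per_run:.2f} s/run, speedup: {speedup:.2f}x")
# prints "TF: 25.13 s/run, ONNX: 10.02 s/run, speedup: 2.51x"
```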
