In [None]:
!pip install -U torch
!pip install -U clip
!pip install -U open_clip_torch
!pip install -U onnx
!pip install -U onnxsim
!pip install -U onnxscript
!pip uninstall -y onnxruntime
!pip uninstall -y onnxruntime-gpu
!pip install -U onnxruntime-gpu

In [None]:
# Optional: download the LAION ViT-L-14 DataComp.XL checkpoint that the
# model-loading cell reads via its `pretrained` path.
# `-O` writes the file directly as open_clip_pytorch_model.bin, so no rename is
# needed; quoting the URL keeps the shell from interpreting the `?`.
# !wget -O open_clip_pytorch_model.bin "https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/resolve/main/open_clip_pytorch_model.bin?download=true"

In [None]:
import onnxruntime as ort

# Sanity-check the installed ONNX Runtime build: the providers listed here are
# the execution backends this build can use (look for CUDAExecutionProvider).
# print("ONNX Runtime version:", ort.__version__)
providers_found = ort.get_available_providers()
print("Available providers:", providers_found)

In [None]:
import os
import clip
import open_clip
import numpy as np
from PIL import Image

# No device is passed below, so the model is built on open_clip's default device
# (CPU), which is what the ONNX export later traces.
# (Original note: "onnx cannot export with cuda".)
# model, preprocess = clip.load("CLIP-ViT-L-14-DataComp.XL-s13B-b90K/open_clip_pytorch_model.bin", device="cpu", jit=False)
model_name = "ViT-L-14"
# Pretrained tags open_clip knows for this architecture. Note that
# `open_clip.list_pretrained(model_name)` would pass the name as the `as_str`
# flag and print *every* pretrained model rather than filtering.
print(open_clip.list_pretrained_tags_by_model(model_name))

# Local checkpoint file; set CLIP_CHECKPOINT to point somewhere else.
pretrained = os.environ.get(
    "CLIP_CHECKPOINT", "/home/haoyu/projects/model2onnx/open_clip_pytorch_model.bin"
)
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name=model_name,
    pretrained=pretrained,
)
# open_clip returns the model in train mode; switch to inference mode so the
# export and the PyTorch-vs-ONNX comparison use eval-time behaviour.
model.eval()

In [None]:
import os

# Preprocess one sample image into a float32 batch of one
# (expected [1, 3, 224, 224] for ViT-L-14). Set CLIP_TEST_IMAGE to use another file.
image_path = os.environ.get("CLIP_TEST_IMAGE", "/home/haoyu/projects/model2onnx/cat.jpg")
# The context manager closes the file handle once preprocessing has read the pixels.
with Image.open(image_path) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).cpu()
image_onnx = image.detach().cpu().numpy().astype(np.float32)

In [None]:
import onnx
import torch
import numpy as np
import onnxruntime as ort
from onnxsim import simplify

# Trace the vision tower with a random batch. Only the shape matters for
# tracing; axis 0 is declared dynamic so other batch sizes work at runtime.
dummy_input_shape = (4, 3, 224, 224)
dummy_input = torch.randn(dummy_input_shape, dtype=torch.float32)
onnx_path = "clip_model_raw.onnx"

export_options = {
    "opset_version": 18,
    "do_constant_folding": True,
    "input_names": ["input"],
    "output_names": ["output"],
    "dynamic_axes": {
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
}
torch.onnx.export(model.visual, dummy_input, onnx_path, **export_options)
print(f"Exported raw model to {onnx_path}")

In [None]:
# Simplify the raw export with onnxsim; `check` reports whether onnxsim could
# validate the simplified graph against the original.
onnx_model = onnx.load(onnx_path)
model_simplified, check = simplify(onnx_model)
print(model_simplified.graph.input[0])
print(model_simplified.graph.output[0])

# Defined before the branch so the smoke test below can never hit a NameError
# when validation fails (previously it was only set inside `if check:`).
onnx_simplified_path = "clip_model_simplified.onnx"
if check:
    onnx.save(model_simplified, onnx_simplified_path)
    print(f"Simplified model saved to {onnx_simplified_path}")
    smoke_test_path = onnx_simplified_path
else:
    print("⚠️ Simplified model could not be validated — check graph manually.")
    # Nothing was saved, so smoke-test the raw export instead.
    smoke_test_path = onnx_path

# Explicit providers make the device choice deterministic (and some GPU builds
# of onnxruntime refuse to create a session without them).
session = ort.InferenceSession(smoke_test_path, providers=["CPUExecutionProvider"])
inputs = {session.get_inputs()[0].name: np.random.randn(*dummy_input_shape).astype(np.float32)}
outputs = session.run(None, inputs)
print("Inference success, output shape:", [o.shape for o in outputs])

In [None]:
# Reload the simplified model from disk and show the declared shape of its
# first input and first output (the exported batch axis is expected to appear
# as a named dim).
onnx_simplified_path = "clip_model_simplified.onnx"
onnx_model = onnx.load(onnx_simplified_path)
graph = onnx_model.graph

input_shape = graph.input[0].type.tensor_type.shape
print("Input shape:", input_shape)

output_shape = graph.output[0].type.tensor_type.shape
print("Output shape:", output_shape)

In [None]:
import os

# Run the same preprocessed image through PyTorch and the simplified ONNX graph.
# Set CLIP_TEST_IMAGE to use another file.
image_path = os.environ.get("CLIP_TEST_IMAGE", "/home/haoyu/projects/model2onnx/cat.jpg")
with Image.open(image_path) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).cpu()  # expected [1, 3, 224, 224]
image_onnx = image.detach().cpu().numpy().astype(np.float32)
print("Input image shape for ONNX model:", image_onnx.shape)

# Reference embedding from PyTorch. eval() is idempotent and guards against the
# model still being in train mode.
model.eval()
with torch.no_grad():
    original_output = model.encode_image(torch.from_numpy(image_onnx))
original_output_np = original_output.cpu().numpy()
print("Original model output shape:", original_output_np.shape)
ori_output = original_output_np[0]

# Same input through the simplified ONNX model, pinned to CPU so the comparison
# is fp32 CPU vs fp32 CPU like the PyTorch reference.
onnx_simplified_path = "clip_model_simplified.onnx"
simplified_model = ort.InferenceSession(onnx_simplified_path, providers=["CPUExecutionProvider"])
inputs = {simplified_model.get_inputs()[0].name: image_onnx}
onnx_outputs = simplified_model.run(None, inputs)
onnx_output_np = onnx_outputs[0]

# Per-image embedding shapes. The labels previously said "ONNX" for the PyTorch output.
print("Original model output shape (per image):", ori_output.shape)
print("ONNX simplified model output shape (per image):", onnx_output_np[0].shape)


In [None]:
# Element-wise agreement between the PyTorch and ONNX embeddings of the test image.
MEAN_ABS_DIFF_TOL = 1e-5  # threshold on the *mean* absolute difference

abs_diff = np.abs(ori_output - onnx_output_np[0])
mean_diff = np.mean(abs_diff)
# The max exposes outliers that a small mean can hide.
max_diff = np.max(abs_diff)
print("Mean absolute difference between original and ONNX output:", mean_diff)
print("Max absolute difference between original and ONNX output:", max_diff)

# mean_diff is a scalar, so compare directly (wrapping it in np.all added nothing).
is_close = bool(mean_diff < MEAN_ABS_DIFF_TOL)
print("Outputs are almost identical:", is_close)

In [None]:
import onnxruntime
from time import perf_counter

# Latency benchmark of an exported ONNX model on random input.
# Previously recorded results (kept for reference):
#   clip-visual.onnx,    batch size 6: 4.10 ms average over 100 runs
#   clip_model_raw.onnx, batch size 4: 39.69 ms average over 100 runs
# Benchmark-specific names keep this cell from overwriting `dummy_input_shape`
# and `onnx_simplified_path`, which earlier cells rely on.
bench_input_shape = (4, 3, 224, 224)
bench_model_path = "clip_model_raw.onnx"
num_warmup_runs = 50
num_runs = 100

# Choose providers when the session is created (CUDA first if this build
# offers it) instead of creating a default session and switching afterwards.
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
    bench_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
    bench_providers = ["CPUExecutionProvider"]
session = onnxruntime.InferenceSession(bench_model_path, providers=bench_providers)
print("Benchmark providers:", session.get_providers())

inputs = {session.get_inputs()[0].name: np.random.randn(*bench_input_shape).astype(np.float32)}

# Warm-up runs keep one-off startup costs out of the timed loop.
for _ in range(num_warmup_runs):
    session.run(None, inputs)

start_time = perf_counter()
for _ in range(num_runs):
    session.run(None, inputs)
avg_time = (perf_counter() - start_time) / num_runs
print(f"Average inference time over {num_runs} runs with batch size {bench_input_shape[0]}: {avg_time * 1000:.2f} ms")
