# CLIP 转 ONNX

## 1. 环境配置

In [1]:
import os

# Point Hugging Face Hub traffic at the hf-mirror.com mirror instead of the default endpoint.
# NOTE(review): presumably this has to run before `transformers` is imported for it to take effect — confirm.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

In [2]:
# !pip install -U onnx onnxruntime onnxruntime-tools onnxruntime-gpu

In [3]:
# Linux/macOS equivalent (a regex alternation is needed; extra bare words would be treated as file names):
# !pip list | grep -iE "torch|onnx"
!pip list | findstr /i "torch torchvision onnx onnxruntime onnxruntime-tools"

onnx                      1.17.0
onnxruntime               1.20.1
onnxruntime-gpu           1.20.1
onnxruntime-tools         1.7.0
torch                     2.5.1+cu124
torchaudio                2.5.1+cu124
torchvision               0.20.1+cu124


In [4]:
!nvidia-smi

Thu Nov 28 00:41:28 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94                 Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   43C    P8              4W /  119W |    1199MiB /   8188MiB |     35%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:55:00_Pacific_Daylight_Time_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0


## 2. 生成图文 Embedding

In [6]:
import torch
import clip
import onnx
import onnxruntime as ort
import numpy as np

import utils

from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.nn.functional import cosine_similarity

In [7]:
# Directory and file names used throughout the notebook.
MODEL_PATH = 'workspace'
DATA_PATH = 'data'
TEXT_ONNX_PATH = 'clip_text.onnx'
IMG_ONNX_PATH = 'clip_image.onnx'
VIT_ONNX_PATH = 'clip_vit.onnx'

# Resolve concrete paths through the project-local `utils.gen_abspath` helper.
# NOTE(review): presumably it joins `directory` and `rel_path` into an absolute path — confirm in utils.py.
img_path = utils.gen_abspath(directory=DATA_PATH, rel_path='cat.JPG')
model_path = utils.gen_abspath(directory='./', rel_path=MODEL_PATH)

# Destination files for the three ONNX exports performed below.
text_onnx_path = utils.gen_abspath(directory=MODEL_PATH, rel_path=TEXT_ONNX_PATH)
img_onnx_path = utils.gen_abspath(directory=MODEL_PATH, rel_path=IMG_ONNX_PATH)
vit_onnx_path = utils.gen_abspath(directory=MODEL_PATH, rel_path=VIT_ONNX_PATH)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')
print(f'torch.version.cuda: {torch.version.cuda}')

device: cuda
torch.version.cuda: 12.4


In [9]:
# Load the pretrained CLIP model and its processor; downloads are cached under model_path
model_name = 'openai/clip-vit-base-patch32'
model = CLIPModel.from_pretrained(model_name, cache_dir=model_path)
processor = CLIPProcessor.from_pretrained(model_name, cache_dir=model_path)

# Switch to evaluation mode and move the model to the selected device.
# Assigning to `_` keeps the returned module's repr out of the cell output.
model.eval()
_ = model.to(device)

In [10]:
# Text embedding
def get_text_embedding(texts, device):
    """Encode texts with the notebook-level CLIP ``model`` and L2-normalize the result.

    Args:
        texts: Text input forwarded to ``processor(text=...)``; padding and
            truncation are enabled.
        device: Device the tokenized batch is moved to before the forward pass.

    Returns:
        The output of ``model.get_text_features`` with each row divided by
        its L2 norm along the last dimension.
    """
    batch = processor(text=texts, return_tensors="pt", padding=True, truncation=True)
    batch = batch.to(device)
    with torch.no_grad():
        raw_features = model.get_text_features(**batch)
    # Per-row L2 norm, kept as a trailing singleton dim so the division broadcasts
    row_norms = raw_features.norm(p=2, dim=-1, keepdim=True)
    return raw_features / row_norms

# Image embedding
def get_image_embedding(image_path, device):
    """Encode one image file with the notebook-level CLIP ``model`` and L2-normalize the result.

    Args:
        image_path: Path of the image file to open with PIL.
        device: Device the preprocessed pixel batch is moved to before the
            forward pass.

    Returns:
        The output of ``model.get_image_features`` with each row divided by
        its L2 norm along the last dimension.
    """
    # Open inside a context manager so the file handle is always released;
    # convert() returns an independent in-memory copy that outlives the file.
    with Image.open(image_path) as img_file:
        image = img_file.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        embeddings = model.get_image_features(**inputs)
    image_embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
    return image_embeddings

In [11]:
# Encode two candidate captions; the displayed shape is (number of texts, embedding size)
texts = ["A photo of a cat", "A picture of a dog"]
text_embeddings = get_text_embedding(texts, device=device)
text_embeddings.shape

torch.Size([2, 512])

In [12]:
# Encode the sample image referenced by img_path; a single image gives one embedding row
image_embedding = get_image_embedding(img_path, device=device)
image_embedding.shape

torch.Size([1, 512])

In [13]:
# Cosine similarity of every caption embedding against the single image embedding
similarity = cosine_similarity(text_embeddings, image_embedding)
similarity

tensor([0.2863, 0.2313], device='cuda:0')

In [14]:
# Both embeddings are already L2-normalized, so a plain matrix product yields
# the same cosine similarities, here as a (num_texts, 1) column
similarity = (text_embeddings @ image_embedding.T)
similarity

tensor([[0.2863],
        [0.2313]], device='cuda:0')

## 3. 转换为 ONNX

In [15]:
def export_text_model_to_onnx(model, output_path, device, sequence_length=77, opset_version=20):
    """Export the text tower (``model.text_model``) to an ONNX file.

    The graph takes a single ``input_ids`` tensor with dynamic batch and
    sequence axes. Its first output keeps the historical name
    ``text_features``, but it is the text tower's per-token hidden states
    (rank 3: batch, sequence, hidden) — not the projected, pooled embedding
    returned by ``get_text_features``.

    Args:
        model: CLIP model exposing a ``text_model`` sub-module.
        output_path: Destination path of the ``.onnx`` file.
        device: Device on which the dummy tracing input is placed; should
            match the device the model lives on.
        sequence_length: Length of the dummy token sequence used for tracing
            (defaults to 77, as before).
        opset_version: ONNX opset to target (defaults to 20, as before).
    """
    # Dummy token ids used only to trace the graph
    dummy_input = torch.randint(0, 77, (1, sequence_length)).to(device)

    torch.onnx.export(
        model.text_model,
        dummy_input,
        output_path,
        input_names=["input_ids"],
        output_names=["text_features"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            # The output carries one vector per token, so its sequence axis
            # must be declared dynamic together with the input's.
            "text_features": {0: "batch_size", 1: "sequence_length"},
        },
        opset_version=opset_version
    )

def export_image_model_to_onnx(model, output_path, device, image_size=224, opset_version=20):
    """Export the vision tower (``model.vision_model``) to an ONNX file.

    The graph takes a single ``pixel_values`` tensor with a dynamic batch
    axis. Its first output keeps the historical name ``image_features``, but
    it is the vision tower's per-patch hidden states (rank 3) — not the
    projected embedding returned by ``get_image_features``.

    Args:
        model: CLIP model exposing a ``vision_model`` sub-module.
        output_path: Destination path of the ``.onnx`` file.
        device: Device on which the dummy tracing input is placed; should
            match the device the model lives on.
        image_size: Height and width of the square dummy image used for
            tracing (defaults to 224, as before).
        opset_version: ONNX opset to target (defaults to 20, as before).
    """
    # Dummy image batch of shape (1, 3, image_size, image_size) used only to trace the graph
    dummy_input = torch.randn(1, 3, image_size, image_size).to(device)

    torch.onnx.export(
        model.vision_model,
        dummy_input,
        output_path,
        input_names=["pixel_values"],
        output_names=["image_features"],
        dynamic_axes={"pixel_values": {0: "batch_size"}, "image_features": {0: "batch_size"}},
        opset_version=opset_version
    )

In [17]:
# Export the text model to text_onnx_path
export_text_model_to_onnx(model,
                          output_path=text_onnx_path,
                          device=device)

In [18]:
# Export the image model to img_onnx_path
export_image_model_to_onnx(model,
                           output_path=img_onnx_path,
                           device=device)

## 4. 验证 ONNX 模型

### 1）文本模型

In [19]:
# Load the exported text ONNX model
text_session = ort.InferenceSession(text_onnx_path)

# Random int64 token ids standing in for a tokenized text sequence
dummy_input = np.random.randint(0, 77, (1, 77)).astype(np.int64)

# Run inference; passing None as the output list returns every graph output
outputs = text_session.run(None, {"input_ids": dummy_input})
outputs[0].shape

(1, 77, 512)

In [20]:
# First graph output: one hidden vector per token position
output = outputs[0]

# Average pooling over the sequence axis
ap_text_embedding = np.mean(output, axis=1)
print(f'Average Pooling: {ap_text_embedding.shape}')

# Take the feature vector at the last sequence position.
# Index -1 is required here: index 0 selects the FIRST token, which
# contradicts the "Last Position" label printed below.
# NOTE(review): CLIP itself presumably pools at the end-of-text token rather
# than a fixed position — confirm before using this as a sentence embedding.
lp_text_embedding = output[:, -1, :]
print(f'Last Position: {lp_text_embedding.shape}')

Average Pooling: (1, 512)
Last Position: (1, 512)


### 2）图片模型

In [21]:
# Load the exported image ONNX model
image_session = ort.InferenceSession(img_onnx_path)

# Random float32 tensor standing in for a preprocessed image batch
dummy_image = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Run inference; passing None as the output list returns every graph output
outputs = image_session.run(None, {"pixel_values": dummy_image})
outputs[0].shape

(1, 50, 768)

In [22]:
# Load the exported image model as an ONNX protobuf object
img_onnx_model = onnx.load(img_onnx_path)

# Print the name and declared shape of every graph output
for output in img_onnx_model.graph.output:
    print(f"Output name: {output.name}, Shape: {output.type.tensor_type.shape}")

Output name: image_features, Shape: dim {
  dim_param: "batch_size"
}
dim {
  dim_param: "Addimage_features_dim_1"
}
dim {
  dim_param: "Addimage_features_dim_2"
}

Output name: 1274, Shape: dim {
  dim_param: "batch_size"
}
dim {
  dim_param: "Addimage_features_dim_2"
}



In [23]:
# Validate the model; onnx.checker raises an exception if the graph is malformed
onnx.checker.check_model(img_onnx_model)

## 5. 直接计算图文相似度

In [24]:
# Load OpenAI's reference CLIP implementation together with its preprocessing transform
m, pre = clip.load("ViT-B/32", device=device, download_root=model_path)
npx = m.visual.input_resolution
# Dummy batch: 10 random images at the model's input resolution and 2 tokenized texts
dummy_image = torch.randn(10, 3, npx, npx).to(device)
dummy_texts = clip.tokenize(["quick brown fox", "lorem ipsum"]).to(device)
# Inference only: skip autograd bookkeeping, and call the module itself
# rather than .forward() so that any registered hooks are honoured.
with torch.no_grad():
    logits_per_image, logits_per_text = m(dummy_image, dummy_texts)
logits_per_image, logits_per_text

(tensor([[19.7188, 24.3281],
         [19.9688, 24.9375],
         [19.6406, 24.6875],
         [19.8438, 24.7344],
         [19.9688, 24.7031],
         [19.7812, 24.9062],
         [20.1719, 24.6719],
         [19.4375, 24.7500],
         [19.8125, 24.5000],
         [19.5781, 24.8906]], device='cuda:0', dtype=torch.float16,
        grad_fn=<MmBackward0>),
 tensor([[19.7188, 19.9688, 19.6406, 19.8438, 19.9688, 19.7812, 20.1719, 19.4375,
          19.8125, 19.5781],
         [24.3281, 24.9375, 24.6875, 24.7344, 24.7031, 24.9062, 24.6719, 24.7500,
          24.5000, 24.8906]], device='cuda:0', dtype=torch.float16,
        grad_fn=<TBackward0>))

In [25]:
# Export the complete CLIP model (image + text inputs, both logits matrices as outputs) to ONNX
torch.onnx.export(
    m,
    (dummy_image, dummy_texts),
    vit_onnx_path,
    export_params=True,
    input_names=["IMAGE", "TEXT"],
    output_names=["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"],
    opset_version=20,
    # Image and text batch sizes vary independently; each logits matrix is
    # indexed by both, in opposite order.
    dynamic_axes={
        "IMAGE": {
            0: "image_batch_size",
        },
        "TEXT": {
            0: "text_batch_size",
        },
        "LOGITS_PER_IMAGE": {
            0: "image_batch_size",
            1: "text_batch_size",
        },
        "LOGITS_PER_TEXT": {
            0: "text_batch_size",
            1: "image_batch_size",
        },
    }
)

# Run the exported graph with ONNX Runtime on the same dummy inputs (moved to host memory as NumPy arrays)
ort_sess = ort.InferenceSession(vit_onnx_path)
result = ort_sess.run(
    ["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"],
    {"IMAGE": dummy_image.cpu().numpy(), "TEXT": dummy_texts.cpu().numpy()})

# Display the ONNX logits so they can be compared with the PyTorch result from the previous cell
result

[array([[19.72, 24.33],
        [19.95, 24.94],
        [19.66, 24.7 ],
        [19.86, 24.77],
        [19.97, 24.72],
        [19.78, 24.92],
        [20.16, 24.67],
        [19.44, 24.75],
        [19.81, 24.52],
        [19.6 , 24.9 ]], dtype=float16),
 array([[19.72, 19.95, 19.66, 19.86, 19.97, 19.78, 20.16, 19.44, 19.81,
         19.6 ],
        [24.33, 24.94, 24.7 , 24.77, 24.72, 24.92, 24.67, 24.75, 24.52,
         24.9 ]], dtype=float16)]

参考：

- [triton-inference-server/tutorials](https://github.com/triton-inference-server/tutorials)
- [preprocessing](https://github.com/triton-inference-server/python_backend/blob/main/examples/preprocessing/README.md)
- [Model_Ensembles](https://github.com/triton-inference-server/tutorials/blob/main/Conceptual_Guide/Part_5-Model_Ensembles/README.md)