In [None]:
! pip install onnx
! pip install onnxruntime


Collecting onnxruntime
  Downloading onnxruntime-1.22.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.22.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m97.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected pack

In [None]:
import torch
import torch.nn as nn
from transformers import CLIPModel
from onnxruntime.quantization import quantize_dynamic, QuantType

class CLIPClassifier(nn.Module):
    """Linear classification head on top of a CLIP backbone.

    The backbone's image and text projections are concatenated (image first,
    then text) and passed through a single ``nn.Linear`` layer.

    Args:
        clip_model: a ``transformers.CLIPModel`` (or any module whose forward
            accepts ``input_ids``/``pixel_values``/``attention_mask`` keyword
            arguments and returns an object with ``image_embeds`` and
            ``text_embeds``).
        num_classes: number of output logits.
        embed_dim: width of each projected embedding. When ``None`` (default)
            it is read from ``clip_model.config.projection_dim`` so the head
            matches whichever CLIP checkpoint is passed in (512 for
            ``openai/clip-vit-base-patch32``, the width previously hard-coded).
    """

    def __init__(self, clip_model, num_classes=2, embed_dim=None):
        super().__init__()
        if embed_dim is None:
            embed_dim = clip_model.config.projection_dim
        self.clip = clip_model
        # Image and text embeddings are concatenated, hence 2 * embed_dim inputs.
        self.fc = nn.Linear(2 * embed_dim, num_classes)

    def forward(self, input_ids, pixel_values, attention_mask):
        """Run CLIP on a batch of (text, image) pairs and return class logits.

        Returns:
            Tensor of shape (batch, num_classes) — assumes the backbone returns
            (batch, embed_dim) embeddings for both modalities.
        """
        outputs = self.clip(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        combined = torch.cat([outputs.image_embeds, outputs.text_embeds], dim=1)
        return self.fc(combined)

# Input checkpoint and output artifacts (relative to the notebook's working directory).
CHECKPOINT_PATH = "trained_model.pth"
ONNX_PATH = "clip_classifier.onnx"
ONNX_QUANTIZED_PATH = "clip_classifier_quantized.onnx"

base_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPClassifier(base_clip, num_classes=2)
# A state_dict only contains tensors, so weights_only=True is sufficient and
# prevents unpickling arbitrary objects (code execution) from the checkpoint file.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Dummy inputs used only to trace the graph; batch and sequence length are
# declared dynamic in `dynamic_axes` below.
input_ids = torch.zeros(1, 77, dtype=torch.long)
pixel_values = torch.randn(1, 3, 224, 224)
attention_mask = torch.ones(1, 77, dtype=torch.long)

torch.onnx.export(
    model,
    (input_ids, pixel_values, attention_mask),
    ONNX_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids", "pixel_values", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "pixel_values": {0: "batch"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "logits": {0: "batch"}
    }
)

# Dynamic (weight-only, activations quantized at runtime) quantization of the exported graph.
quantize_dynamic(
    ONNX_PATH,
    ONNX_QUANTIZED_PATH,
    weight_type=QuantType.QUInt8
)

print("Done")



Done


In [None]:
import onnxruntime as ort

# Same relative path the export cell writes to; the former absolute
# "/content/..." path only resolved on Colab.
onnx_path = "clip_classifier_quantized.onnx"
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
# Loop variable is `node`, not `input`: a top-level `input` would shadow the
# builtin for the rest of the kernel session.
for node in session.get_inputs():
    print(f"Input: {node.name}, Shape: {node.shape}, Type: {node.type}")
for node in session.get_outputs():
    print(f"Output: {node.name}, Shape: {node.shape}, Type: {node.type}")

Input: input_ids, Shape: ['batch', 'sequence'], Type: tensor(int64)
Input: pixel_values, Shape: ['batch', 3, 224, 224], Type: tensor(float)
Input: attention_mask, Shape: ['batch', 'sequence'], Type: tensor(int64)
Output: logits, Shape: ['batch', 2], Type: tensor(float)


In [None]:
! pip freeze

absl-py==1.4.0
accelerate==1.9.0
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
alabaster==1.0.0
albucore==0.0.24
albumentations==2.0.8
ale-py==0.11.2
altair==5.5.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.10.0
anywidget==0.9.18
argon2-cffi==25.1.0
argon2-cffi-bindings==25.1.0
array_record==0.7.2
arviz==0.22.0
astropy==7.1.0
astropy-iers-data==0.2025.8.4.0.42.59
astunparse==1.6.3
atpublic==5.1
attrs==25.3.0
audioread==3.0.1
autograd==1.8.0
babel==2.17.0
backcall==0.2.0
backports.tarfile==1.2.0
beautifulsoup4==4.13.4
betterproto==2.0.0b6
bigframes==2.13.0
bigquery-magics==0.10.2
bleach==6.2.0
blinker==1.9.0
blis==1.3.0
blobfile==3.0.0
blosc2==3.6.1
bokeh==3.7.3
Bottleneck==1.4.2
bqplot==0.12.45
branca==0.8.1
Brotli==1.1.0
build==1.3.0
CacheControl==0.14.3
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.8.3
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.2
chex==0.1.90
clarabel==0.11.1
click==8.2.1
cloudpathlib==0.21.1
cloudpickle==3