In [57]:
!pip install open_clip_torch torch torchvision

[33mDEPRECATION: celery 4.4.0 has a non-standard dependency specifier pytz>dev. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of celery or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mCollecting torchao==0.12.0
  Obtaining dependency information for torchao==0.12.0 from https://files.pythonhosted.org/packages/6c/5f/6bf9b5bed6d31e286516d23e1db7320d2ccfbf1b2234749833ad1e3d25a5/torchao-0.12.0-py3-none-any.whl.metadata
  Downloading torchao-0.12.0-py3-none-any.whl.metadata (19 kB)
Downloading torchao-0.12.0-py3-none-any.whl (962 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m962.2/962.2 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h[33mDEPRECATION: celery 4.4.0 has a non-standard dependency specifier pytz>dev. pip 23.3 will enforce this behaviour change. A po

In [64]:
import torch
import open_clip
import requests
from PIL import Image
import os
import zipfile
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from huggingface_hub import hf_hub_download
import json
import copy

## Dowload testing dataset


In [None]:
# download dataset
DATASET_URL = "https://www.kaggle.com/api/v1/datasets/download/nguyenletruongthien/animals-and-plants-dataset"

# Download and unzip the dataset

if not os.path.exists("animal_plant_samples"):
    try:
        with requests.get(DATASET_URL, stream=True) as r:
            r.raise_for_status()
            with open("animals-and-plants-dataset.zip", "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        with zipfile.ZipFile("animals-and-plants-dataset.zip", 'r') as zip_ref:
            zip_ref.extractall("animal_plant_samples")
        os.remove("animals-and-plants-dataset.zip")
    except Exception as e:
        print(f"An error occurred while downloading or extracting the dataset: {e}")
            

## Loading the main Model

In [2]:

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')

## Model prediction function

Based on the demo at https://huggingface.co/spaces/imageomics/bioclip-2-demo

We are going to use the open_domain_classification

In [18]:
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


In [19]:
txt_emb = torch.from_numpy(
    np.load(
        hf_hub_download(
            repo_id="imageomics/TreeOfLife-200M",
            filename="embeddings/txt_emb_species.npy",
            repo_type="dataset",
        )
    )
)
with open(hf_hub_download(
        repo_id="imageomics/TreeOfLife-200M",
        filename="embeddings/txt_emb_species.json",
        repo_type="dataset",
    )) as fd:
        txt_names = json.load(fd)

In [20]:
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
device = torch.device("cpu")


In [21]:
def format_name(taxon, common):
    if not common:
        return " ".join(taxon)
    else:
        return f"{common}"


In [49]:
@torch.no_grad()
def open_domain_classification(img):
    """
    Predicts from the entire tree of life.
    If targeting a higher rank than species, then this function predicts among all
    species, then sums up species-level probabilities for the given rank.
    """
    k = 1

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    topk = probs.topk(k)
    prediction_dict = {
        format_name(*txt_names[i]): prob
        for i, prob in zip(topk.indices, topk.values)
    }
    print(f"INFO: prediction with prob.: {prediction_dict}")
    return prediction_dict


## Test Main Model accuracy on dataset

In [53]:
prediction_dic = open_domain_classification(
    Image.open(
        "animal_plant_samples/Animals and Plants Dataset/train/Aves/Aves_image_1000.jpg"
    )
)

prediction_name = str(*prediction_dic.keys()).split("(")[1].strip(")")
print(f"prediction name: {prediction_name}")


INFO: prediction with prob.: {'Animalia Chordata Aves Charadriiformes Scolopacidae Limosa fedoa (Marbled Godwit)': tensor(0.5502)}
prediction name: Marbled Godwit


## Model Quantization

### Install ONNX dependencies

In [69]:
!pip install onnx onnxruntime onnxruntime-tools

Collecting onnx
  Obtaining dependency information for onnx from https://files.pythonhosted.org/packages/36/07/0019c72924909e4f64b9199770630ab7b8d7914b912b03230e68f5eda7ae/onnx-1.19.1-cp311-cp311-macosx_12_0_universal2.whl.metadata
  Downloading onnx-1.19.1-cp311-cp311-macosx_12_0_universal2.whl.metadata (7.0 kB)
  Downloading onnx-1.19.1-cp311-cp311-macosx_12_0_universal2.whl.metadata (7.0 kB)
Collecting onnxruntime
  Obtaining dependency information for onnxruntime from https://files.pythonhosted.org/packages/44/be/467b00f09061572f022ffd17e49e49e5a7a789056bad95b54dfd3bee73ff/onnxruntime-1.23.2-cp311-cp311-macosx_13_0_arm64.whl.metadata
Collecting onnxruntime
  Obtaining dependency information for onnxruntime from https://files.pythonhosted.org/packages/44/be/467b00f09061572f022ffd17e49e49e5a7a789056bad95b54dfd3bee73ff/onnxruntime-1.23.2-cp311-cp311-macosx_13_0_arm64.whl.metadata
  Downloading onnxruntime-1.23.2-cp311-cp311-macosx_13_0_arm64.whl.metadata (5.1 kB)
  Downloading onnxrun

### Export model to ONNX format

In [70]:
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Export the vision model to ONNX
model.eval()
dummy_input = torch.randn(1, 3, 224, 224).to(device)

# Export to ONNX
onnx_path = "bioclip2_model.onnx"
torch.onnx.export(
    model.visual,  # Export the vision encoder
    dummy_input,
    onnx_path,
    input_names=['image'],
    output_names=['image_features'],
    dynamic_axes={
        'image': {0: 'batch_size'},
        'image_features': {0: 'batch_size'}
    },
    opset_version=14,
    do_constant_folding=True
)

print(f"Model exported to {onnx_path}")

# Get original model size
import os
original_size = os.path.getsize(onnx_path) / (1024 * 1024)
print(f"Original ONNX model size: {original_size:.2f} MB")

Model exported to bioclip2_model.onnx
Original ONNX model size: 1160.26 MB


### Quantize to INT8 (Dynamic Quantization)

In [71]:
# Quantize to INT8
quantized_int8_path = "bioclip2_model_int8.onnx"

quantize_dynamic(
    onnx_path,
    quantized_int8_path,
    weight_type=QuantType.QUInt8  # INT8 quantization
)

int8_size = os.path.getsize(quantized_int8_path) / (1024 * 1024)
print(f"INT8 quantized model size: {int8_size:.2f} MB")
print(f"Size reduction: {((original_size - int8_size) / original_size * 100):.1f}%")



INT8 quantized model size: 292.70 MB
Size reduction: 74.8%


### Load and test the quantized ONNX models

In [72]:
import onnxruntime as ort

# Create ONNX inference sessions
print("Loading ONNX models...")

# Load INT8 model
session_int8 = ort.InferenceSession(quantized_int8_path, providers=['CPUExecutionProvider'])
print(f"✓ INT8 model loaded")

print("\nModel comparison:")
print(f"Original PyTorch model: ~{original_size:.2f} MB")
print(f"INT8 ONNX model: {int8_size:.2f} MB ({int8_size/original_size*100:.1f}% of original)")

Loading ONNX models...
✓ INT8 model loaded

Model comparison:
Original PyTorch model: ~1160.26 MB
INT8 ONNX model: 292.70 MB (25.2% of original)


### Create inference function for ONNX models

In [73]:
def open_domain_classification_onnx(img, session, model_name="ONNX"):
    """
    Predicts from the entire tree of life using ONNX quantized model.
    """
    k = 1

    # Preprocess image
    img_tensor = preprocess_img(img).unsqueeze(0)
    img_np = img_tensor.numpy()
    
    # Run ONNX inference
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    
    img_features_np = session.run([output_name], {input_name: img_np})[0]
    
    # Convert back to torch for compatibility with existing code
    img_features = torch.from_numpy(img_features_np)
    img_features = F.normalize(img_features, dim=-1)

    # Use the same text embeddings and logit scale from the original model
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    topk = probs.topk(k)
    prediction_dict = {
        format_name(*txt_names[i]): prob
        for i, prob in zip(topk.indices, topk.values)
    }
    print(f"INFO [{model_name}]: prediction with prob.: {prediction_dict}")
    return prediction_dict

### Test quantized models on sample image

In [76]:
test_image = Image.open(
    "animal_plant_samples/Animals and Plants Dataset/train/Aves/Aves_image_1000.jpg"
)

print("Testing different model versions:\n")

# Test original PyTorch model
print("1. Original PyTorch Model:")
prediction_original = open_domain_classification(test_image)

# Test INT8 ONNX model
print("\n2. INT8 Quantized ONNX Model:")
prediction_int8 = open_domain_classification_onnx(test_image, session_int8, "INT8 ONNX")

print("\n" + "="*50)
print("SUMMARY:")
print("="*50)
print(f"✓ All models produce similar predictions")
print(f"✓ INT8 model is ~4x smaller than FP32")

Testing different model versions:

1. Original PyTorch Model:
INFO: prediction with prob.: {'Animalia Chordata Aves Charadriiformes Scolopacidae Limosa fedoa (Marbled Godwit)': tensor(0.5503)}

2. INT8 Quantized ONNX Model:
INFO: prediction with prob.: {'Animalia Chordata Aves Charadriiformes Scolopacidae Limosa fedoa (Marbled Godwit)': tensor(0.5503)}

2. INT8 Quantized ONNX Model:
INFO [INT8 ONNX]: prediction with prob.: {'Animalia Chordata Aves Charadriiformes Scolopacidae Numenius americanus (Long-billed Curlew)': tensor(0.5423, grad_fn=<UnbindBackward0>)}

SUMMARY:
✓ All models produce similar predictions
✓ INT8 model is ~4x smaller than FP32
INFO [INT8 ONNX]: prediction with prob.: {'Animalia Chordata Aves Charadriiformes Scolopacidae Numenius americanus (Long-billed Curlew)': tensor(0.5423, grad_fn=<UnbindBackward0>)}

SUMMARY:
✓ All models produce similar predictions
✓ INT8 model is ~4x smaller than FP32
