
[request] Additional model exports (ONNX, CoreML, ...) #167

Open · patricklabatut opened this issue Aug 23, 2023 · 2 comments
Labels: enhancement (New feature or request)

patricklabatut (Contributor) commented Aug 23, 2023

barbolo commented Apr 1, 2024

This is how I've used the transformers library to export DINOv2 outputs (class token + patch tokens) to ONNX and OpenVINO:

import torch
from transformers import Dinov2Model

image_width = 224
image_height = 224
model_size = 'small' # small, base, large, giant

# Wrap the model so the exported graph returns a single tensor
# (last_hidden_state) instead of the full transformers output object.
class Wrapper(torch.nn.Module):
    def __init__(self, dinov2_model):
        super().__init__()
        self.dinov2_model = dinov2_model
    def forward(self, tensor):
        return self.dinov2_model(tensor).last_hidden_state

dummy_input = torch.rand([1, 3, image_height, image_width]).to('cpu')

dinov2_model = Dinov2Model.from_pretrained(f'facebook/dinov2-{model_size}')
model = Wrapper(dinov2_model).to('cpu')

torch.onnx.export(model, dummy_input, f'dinov2-{model_size}.onnx')
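
Not part of the original snippet, but as a quick sanity check you can compare the exported graph against the PyTorch wrapper on the same dummy input (a sketch, assuming the export above just ran in the same session):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(f'dinov2-{model_size}.onnx')
onnx_out = session.run(None, {session.get_inputs()[0].name: dummy_input.numpy()})[0]
with torch.no_grad():
    torch_out = model(dummy_input).numpy()
print(np.abs(onnx_out - torch_out).max())  # should be close to zero (e.g. < 1e-4)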

Once you have the ONNX model (e.g. dinov2-small.onnx), you can convert it to an OpenVINO IR with FP16 weight compression using the ovc CLI:

ovc dinov2-small.onnx --output_model openvino/dinov2-small --compress_to_fp16=True

This is how to run inference with both models after conversion:

# onnx
import onnxruntime as ort
import numpy as np
session = ort.InferenceSession('dinov2-small.onnx')
model_inputs = session.get_inputs()
input_shape = model_inputs[0].shape # (1, 3, H, W)
input_height = input_shape[2]
input_width = input_shape[3]
dummy_input = np.random.rand(1, 3, input_height, input_width).astype(np.float32)
outputs = session.run(None, { model_inputs[0].name: dummy_input })
# last_hidden_state has shape (1, 1 + num_patches, hidden_dim):
# token 0 is the class token, the remaining tokens are the patch tokens.
classtoken = outputs[0][0][0]
patchtokens = outputs[0][0][1:]

# openvino
from openvino.runtime import Core
import numpy as np
core = Core()
model = core.read_model(model='openvino/dinov2-small.xml')
compiled_model = core.compile_model(model=model, device_name="CPU")
input_height = compiled_model.input(0).shape[2]
input_width = compiled_model.input(0).shape[3]
dummy_input = np.random.rand(1, 3, input_height, input_width).astype(np.float32)
outputs = compiled_model(dummy_input)
classtoken = outputs[0][0][0]
patchtokens = outputs[0][0][1:]
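
Since the issue also asks about CoreML: the same wrapper can in principle be traced and converted with coremltools. This is only a rough sketch I haven't run here, reusing the model and dummy_input from the export snippet above:

import coremltools as ct

traced = torch.jit.trace(model, dummy_input)
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="image", shape=dummy_input.shape)],
    convert_to="mlprogram",
)
mlmodel.save(f"dinov2-{model_size}.mlpackage")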

barbolo commented Apr 2, 2024

Before feeding an image to the model, you should preprocess it. I've written this function to do that:

import cv2 as cv
import numpy as np

# https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/data/transforms.py#L75-L91
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
MEAN = np.array(IMAGENET_DEFAULT_MEAN)
STD = np.array(IMAGENET_DEFAULT_STD)

# input_width / input_height come from the model input shape read in the
# previous snippets. Returns a float32 array of shape (3, H, W); add a batch
# dimension before passing it to the model.
def preprocess(img):
    if isinstance(img, np.ndarray):
        # OpenCV image (BGR)
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    else:
        # PIL Image
        img = np.array(img)
    img = cv.resize(img, (input_width, input_height), interpolation=cv.INTER_CUBIC)
    img = img / 255.0 # from [0, 255] to [0.0, 1.0]
    img = np.transpose(img, (2, 0, 1)) # from shape (H x W x C) to (C x H x W)
    img = (img - MEAN[:, None, None]) / STD[:, None, None] # transforms.Normalize(mean=MEAN, std=STD)
    img = img.astype(np.float32)
    return img
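
For example (my addition, assuming a local image file such as cat.jpg and the ONNX session / model_inputs from the earlier snippet), the result just needs a batch dimension before inference:

img = preprocess(cv.imread('cat.jpg')) # (3, H, W) float32
outputs = session.run(None, { model_inputs[0].name: img[None] }) # batch dim -> (1, 3, H, W)
classtoken = outputs[0][0][0]
patchtokens = outputs[0][0][1:]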
