[request] Additional model exports (ONNX, CoreML, ...) #167
Labels: enhancement (New feature or request)
This is how I've exported the model to ONNX:

```python
import torch
from transformers import Dinov2Model

image_width = 224
image_height = 224
model_size = 'small'  # small, base, large, giant

class Wrapper(torch.nn.Module):
    """Return only last_hidden_state so the ONNX graph has a single tensor output."""

    def __init__(self, dinov2_model):
        super().__init__()
        self.dinov2_model = dinov2_model

    def forward(self, tensor):
        return self.dinov2_model(tensor).last_hidden_state

dummy_input = torch.rand([1, 3, image_height, image_width]).to('cpu')
dinov2_model = Dinov2Model.from_pretrained(f'facebook/dinov2-{model_size}')
model = Wrapper(dinov2_model).to('cpu').eval()
# No dynamic_axes are passed, so the exported graph has a fixed 1x3x224x224 input.
torch.onnx.export(model, dummy_input, f'dinov2-{model_size}.onnx')
```

Once you have the ONNX model (e.g. `dinov2-small.onnx`), you can convert it to OpenVINO IR with `ovc`:

```shell
ovc dinov2-small.onnx --output_model openvino/dinov2-small --compress_to_fp16=True
```

This is how to use the model after conversion, first with ONNX Runtime and then with OpenVINO:
```python
# onnx
import onnxruntime as ort
import numpy as np

# Passing providers explicitly avoids an error on builds with several execution providers.
session = ort.InferenceSession('dinov2-small.onnx', providers=['CPUExecutionProvider'])
model_inputs = session.get_inputs()
input_shape = model_inputs[0].shape  # [batch, channels, height, width]
input_height = input_shape[2]
input_width = input_shape[3]

dummy_input = np.random.rand(1, 3, input_height, input_width).astype(np.float32)
outputs = session.run(None, {model_inputs[0].name: dummy_input})
# outputs[0] is last_hidden_state: (batch, 1 + num_patches, hidden_size)
classtoken = outputs[0][0][0]    # first token: class (CLS) token
patchtokens = outputs[0][0][1:]  # remaining tokens: one per image patch

# openvino
import openvino as ov  # older releases: from openvino.runtime import Core
import numpy as np

core = ov.Core()
model = core.read_model(model='openvino/dinov2-small.xml')
compiled_model = core.compile_model(model=model, device_name="CPU")
input_height = compiled_model.input(0).shape[2]
input_width = compiled_model.input(0).shape[3]

dummy_input = np.random.rand(1, 3, input_height, input_width).astype(np.float32)
outputs = compiled_model(dummy_input)
classtoken = outputs[0][0][0]
patchtokens = outputs[0][0][1:]
```
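To use `patchtokens` as a spatial feature map, they can be reshaped into a grid. A minimal numpy sketch, assuming a patch size of 14 (so a 224×224 input gives 16×16 = 256 patch tokens, 257 tokens in total), a checkpoint without register tokens such as `facebook/dinov2-small` (hidden size 384), and patch tokens emitted in row-major order over the image; `split_tokens` is a hypothetical helper, not part of ONNX Runtime or OpenVINO:

```python
import numpy as np

PATCH_SIZE = 14  # assumption: DINOv2 ViT-*/14 patch size


def split_tokens(last_hidden_state, input_height, input_width, patch_size=PATCH_SIZE):
    """Split a (1, 1 + N, D) output into a CLS vector (D,) and a patch grid (H/p, W/p, D)."""
    grid_h = input_height // patch_size
    grid_w = input_width // patch_size
    tokens = last_hidden_state[0]  # drop the batch axis -> (1 + N, D)
    if tokens.shape[0] != 1 + grid_h * grid_w:
        raise ValueError(f"expected {1 + grid_h * grid_w} tokens, got {tokens.shape[0]}")
    classtoken = tokens[0]
    # Assumes row-major (raster) patch order, so a plain reshape recovers the 2D layout.
    patch_grid = tokens[1:].reshape(grid_h, grid_w, -1)
    return classtoken, patch_grid


# Example with a fake output shaped like dinov2-small at 224x224
fake_output = np.random.rand(1, 257, 384).astype(np.float32)
classtoken, patch_grid = split_tokens(fake_output, 224, 224)
print(classtoken.shape, patch_grid.shape)  # (384,) (16, 16, 384)
```

With either runtime above, this would be called as `split_tokens(outputs[0], input_height, input_width)`; the token-count check catches a mismatched patch size or a checkpoint with register tokens.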
This was referenced Apr 1, 2024
Before feeding an image to the model, you should preprocess it. I've written this function to do that:

```python
import cv2 as cv
import numpy as np

# https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/data/transforms.py#L75-L91
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
MEAN = np.array(IMAGENET_DEFAULT_MEAN)
STD = np.array(IMAGENET_DEFAULT_STD)

def preprocess(img, input_width=224, input_height=224):
    # input_width/input_height should match the model input size read from the session/compiled model
    if isinstance(img, np.ndarray):
        # OpenCV image (BGR channel order)
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    else:
        # PIL Image; convert('RGB') also drops an alpha channel if present
        img = np.array(img.convert('RGB'))
    img = cv.resize(img, (input_width, input_height), interpolation=cv.INTER_CUBIC)  # dsize is (width, height)
    img = img / 255.0  # from [0, 255] to [0.0, 1.0]
    img = np.transpose(img, (2, 0, 1))  # from shape (H x W x C) to (C x H x W)
    img = (img - MEAN[:, None, None]) / STD[:, None, None]  # transforms.Normalize(mean=MEAN, std=STD)
    img = img.astype(np.float32)
    return img  # shape (3, H, W); add a batch axis with img[None] before inference
```
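`preprocess` resizes straight to the model input size, which stretches non-square images. If I read the linked `transforms.py` correctly, the DINOv2 evaluation transform instead resizes the shorter side to 256 (bicubic) and then center-crops 224×224. A minimal numpy sketch of those two geometry steps, assuming torchvision's rounding conventions for `Resize` and `CenterCrop`; `resize_shorter_side_dsize` and `center_crop` are hypothetical helpers, and the returned size uses `cv.resize`'s `(width, height)` order:

```python
import numpy as np


def resize_shorter_side_dsize(height, width, resize_size=256):
    """(width, height) for cv.resize so that the shorter side becomes resize_size.

    Assumption: the longer side is truncated with int(), as torchvision's Resize does.
    """
    if height <= width:
        return int(resize_size * width / height), resize_size
    return resize_size, int(resize_size * height / width)


def center_crop(img, crop_height=224, crop_width=224):
    """Center-crop an (H, W, C) array; offsets rounded like torchvision's CenterCrop."""
    height, width = img.shape[:2]
    top = int(round((height - crop_height) / 2.0))
    left = int(round((width - crop_width) / 2.0))
    return img[top:top + crop_height, left:left + crop_width]


# Example: a 480x640 (H x W) image
dsize = resize_shorter_side_dsize(480, 640)
print(dsize)  # (341, 256)
resized = np.zeros((dsize[1], dsize[0], 3), dtype=np.uint8)  # stand-in for the cv.resize output
print(center_crop(resized).shape)  # (224, 224, 3)
```

In `preprocess`, these would replace the direct `cv.resize` call: `cv.resize(img, resize_shorter_side_dsize(*img.shape[:2]), interpolation=cv.INTER_CUBIC)` followed by `center_crop(...)`. Values may still differ slightly from the PyTorch pipeline, since OpenCV's bicubic resize is not identical to PIL's.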