[Draft] Add ONNX export script #129

Draft · wants to merge 8 commits into base: main
32 changes: 24 additions & 8 deletions dinov2/models/vision_transformer.py
@@ -180,7 +180,7 @@ def interpolate_pos_encoding(self, x, w, h):

         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
-            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            scale_factor=(float(w0 / math.sqrt(N)), float(h0 / math.sqrt(N))),
             mode="bicubic",
         )
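The only functional change in this hunk is the explicit float() cast on the interpolation scale factors, which keeps scale_factor as plain Python floats when the model is traced for export. A standalone sketch, not part of the diff, with made-up values for N, dim, w0 and h0:

import math
import torch
import torch.nn.functional as F

N, dim = 256, 384                        # 16 x 16 grid of patch position embeddings
patch_pos_embed = torch.randn(1, N, dim)
w0, h0 = 20.1, 20.1                      # target grid size (+0.1, as in the original code)

resized = F.interpolate(
    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
    scale_factor=(float(w0 / math.sqrt(N)), float(h0 / math.sqrt(N))),
    mode="bicubic",
)
print(resized.shape)                     # torch.Size([1, 384, 20, 20])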

@@ -303,46 +303,62 @@ def init_weights_vit_timm(module: nn.Module, name: str = ""):
nn.init.zeros_(module.bias)


-def vit_small(patch_size=16, **kwargs):
+def vit_small(
+    patch_size: int = 16,
+    attn_class: nn.Module = MemEffAttention,
+    **kwargs
+):
     model = DinoVisionTransformer(
         patch_size=patch_size,
         embed_dim=384,
         depth=12,
         num_heads=6,
         mlp_ratio=4,
-        block_fn=partial(Block, attn_class=MemEffAttention),
+        block_fn=partial(Block, attn_class=attn_class),
         **kwargs,
     )
     return model


-def vit_base(patch_size=16, **kwargs):
+def vit_base(
+    patch_size: int = 16,
+    attn_class: nn.Module = MemEffAttention,
+    **kwargs
+):
     model = DinoVisionTransformer(
         patch_size=patch_size,
         embed_dim=768,
         depth=12,
         num_heads=12,
         mlp_ratio=4,
-        block_fn=partial(Block, attn_class=MemEffAttention),
+        block_fn=partial(Block, attn_class=attn_class),
         **kwargs,
     )
     return model


-def vit_large(patch_size=16, **kwargs):
+def vit_large(
+    patch_size: int = 16,
+    attn_class: nn.Module = MemEffAttention,
+    **kwargs
+):
     model = DinoVisionTransformer(
         patch_size=patch_size,
         embed_dim=1024,
         depth=24,
         num_heads=16,
         mlp_ratio=4,
-        block_fn=partial(Block, attn_class=MemEffAttention),
+        block_fn=partial(Block, attn_class=attn_class),
         **kwargs,
     )
     return model


-def vit_giant2(patch_size=16, **kwargs):
+def vit_giant2(
+    patch_size: int = 16,
+    attn_class: nn.Module = MemEffAttention,
+    **kwargs
+):
     """
     Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
     """
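With attn_class exposed on the builders, an export path can construct the backbone with the plain Attention module instead of the xFormers-backed MemEffAttention. A minimal sketch, not part of the diff, assuming the in-repo module paths dinov2.models.vision_transformer and dinov2.layers.attention:

from dinov2.layers.attention import Attention
from dinov2.models.vision_transformer import vit_small

# Plain-attention ViT-S backbone; pretrained weights would still be loaded separately.
model = vit_small(patch_size=14, attn_class=Attention)
model.eval()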
31 changes: 22 additions & 9 deletions hubconf.py
@@ -6,7 +6,7 @@

 import torch
 import torch.nn as nn
-
+from dinov2.layers.attention import Attention
 
dependencies = ["torch"]

@@ -51,32 +51,45 @@ def _make_dinov2_model(
return model


-def dinov2_vits14(*, pretrained: bool = True, **kwargs):
+def dinov2_vits14(*, pretrained: bool = True, for_onnx: bool = False, **kwargs):
     """
     DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
     """
-    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, **kwargs)
+    if not for_onnx:
+        return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, **kwargs)
+    else:
+        return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, attn_class=Attention, **kwargs)


-def dinov2_vitb14(*, pretrained: bool = True, **kwargs):
+def dinov2_vitb14(*, pretrained: bool = True, for_onnx: bool = False, **kwargs):
     """
     DINOv2 ViT-B/14 model pretrained on the LVD-142M dataset.
     """
-    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, **kwargs)
+    if not for_onnx:
+        return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, **kwargs)
+    else:
+        return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, attn_class=Attention, **kwargs)



-def dinov2_vitl14(*, pretrained: bool = True, **kwargs):
+def dinov2_vitl14(*, pretrained: bool = True, for_onnx: bool = False, **kwargs):
     """
     DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
     """
-    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, **kwargs)
+    if not for_onnx:
+        return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, **kwargs)
+    else:
+        return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, attn_class=Attention, **kwargs)


-def dinov2_vitg14(*, pretrained: bool = True, **kwargs):
+def dinov2_vitg14(*, pretrained: bool = True, for_onnx: bool = False, **kwargs):
     """
     DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
     """
-    return _make_dinov2_model(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs)
+    if not for_onnx:
+        return _make_dinov2_model(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs)
+    else:
+        return _make_dinov2_model(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, attn_class=Attention, **kwargs)


def _make_dinov2_linear_head(
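The for_onnx flag gives callers an export-friendly model without touching the builders directly, which is how the conversion script below uses it. A minimal sketch, not part of the diff:

import hubconf

# Pretrained ViT-S/14 built with the plain Attention class instead of MemEffAttention.
model = hubconf.dinov2_vits14(pretrained=True, for_onnx=True)
model.eval()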
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,3 +9,4 @@ xformers==0.0.18
 submitit
 --extra-index-url https://pypi.nvidia.com
 cuml-cu11
+onnx==1.14.0
52 changes: 52 additions & 0 deletions scripts/convert-to-onnx.py
@@ -0,0 +1,52 @@
"""DINOV2 model converter to onnx."""
import torch
import argparse
import os
import sys
from pathlib import Path
current_path = Path(__file__).resolve()
parent_path = current_path.parent.parent.as_posix()
sys.path.insert(0, parent_path)
import hubconf


# Thin wrapper so torch.onnx.export traces a single tensor-in, tensor-out forward pass.
class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, tensor):
        ff = self.model(tensor)
        return ff

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="dinov2_vits14", help="dinov2 model name")
parser.add_argument(
    "--image_height", type=int, default=280, help="input image height, must be a multiple of patch_size"
)
parser.add_argument(
    "--image_width", type=int, default=280, help="input image width, must be a multiple of patch_size"
)
parser.add_argument(
    "--patch_size", type=int, default=14, help="dinov2 model patch size, default is 14"
)
args = parser.parse_args()


if __name__ == "__main__":

    assert args.image_height % args.patch_size == 0, f"image height must be a multiple of {args.patch_size}, but got {args.image_height}"
    assert args.image_width % args.patch_size == 0, f"image width must be a multiple of {args.patch_size}, but got {args.image_width}"

    # Build the requested backbone with the plain Attention class so the graph traces cleanly.
    model = Wrapper(getattr(hubconf, args.model_name)(for_onnx=True)).to("cpu")
    model.eval()

    dummy_input = torch.rand([1, 3, args.image_height, args.image_width]).to("cpu")
    dummy_output = model(dummy_input)

    torch.onnx.export(
        model,
        dummy_input,
        args.model_name + ".onnx",
        input_names=["input"],
    )
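A usage sketch, not part of the PR: run the converter, then sanity-check the exported graph with onnx and onnxruntime. The onnxruntime dependency is an assumption here; it is not added by this PR's requirements.txt, and the (1, 384) output shape is what a ViT-S/14 backbone with an identity head is expected to produce.

# Run the converter from the repository root; it writes dinov2_vits14.onnx to the current directory.
#   python scripts/convert-to-onnx.py --model_name dinov2_vits14 --image_height 280 --image_width 280

import numpy as np
import onnx
import onnxruntime as ort

onnx_model = onnx.load("dinov2_vits14.onnx")
onnx.checker.check_model(onnx_model)        # structural validation of the exported graph

session = ort.InferenceSession("dinov2_vits14.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.rand(1, 3, 280, 280).astype(np.float32)
(features,) = session.run(None, {"input": dummy})
print(features.shape)                       # expected (1, 384): CLS embedding for ViT-S/14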