In [12]:
import torch
from PIL import Image
import open_clip

In [8]:
model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-16-plus-240', pretrained='laion400m_e32')
# open_clip's README notes the model is created in train mode; switch to inference
# mode so the manual embedding steps and the exports below see eval-time behavior.
# Trailing ';' suppresses printing the full module repr.
model.eval();

In [None]:
# Inspect the vision tower's public attributes/submodules (dunder and private names hidden).
[name for name in dir(model.visual) if not name.startswith("_")]

In [13]:
tokenizer = open_clip.get_tokenizer('ViT-B-16-plus-240')

# Sample inputs used to trace the ONNX exports below.
# The context manager closes the image file once the transform has read the pixels.
with Image.open("../docs/CLIP.png") as pil_image:
    image = eval_transform(pil_image).unsqueeze(0)  # add a leading batch dim of 1
text = tokenizer(["a diagram", "a dog", "a cat"])  # one tokenized row per prompt

In [15]:
# Export only the vision tower to ONNX.
# Only the batch axis is declared dynamic. The input channel count and the 240x240
# resolution are presumably fixed by the ViT's patch/positional embeddings (the
# benchmark below uses 1x3x240x240) — verify before declaring those axes dynamic.
torch.onnx.export(
    model.visual,
    image,
    "image_encoder.onnx",  # where to save the model
    opset_version=14,  # the ONNX version to export the model to
    input_names=["image"],  # the model's input names
    output_names=["image_embedding"],  # the model's output names
    dynamic_axes={  # variable length axes
        "image": {0: "batch"},
        "image_embedding": {0: "batch"},
    }
)

In [16]:
# Build the text transformer's input by hand so that model.transformer can be
# exported on its own: embed the token ids, add the positional embedding (both in
# the transformer's cast dtype), then reorder from batch-first to sequence-first.
cast_dtype = model.transformer.get_cast_dtype()
token_embeds = model.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
pos_embeds = model.positional_embedding.to(cast_dtype)
x = (token_embeds + pos_embeds).permute(1, 0, 2)  # NLD -> LND


In [21]:
# Sanity-check the hand-built transformer input before exporting it:
# x must be sequence-first (seq, batch, dim) and the attention mask square over seq.
assert x.shape[:2] == (text.shape[1], text.shape[0]), "expected x in (seq, batch, dim) layout"
assert model.attn_mask.shape == (x.shape[0], x.shape[0]), "attention mask must be [seq, seq]"
x.shape, model.attn_mask.shape, text.shape

(torch.Size([77, 3, 640]), torch.Size([77, 77]), torch.Size([3, 77]))

In [17]:
# Export only the text transformer stack to ONNX.
# Its first input is the sequence-first embedding tensor `x` built earlier
# (permuted NLD -> LND, e.g. [77, 3, 640]), not raw token ids. The axis labels
# therefore follow that layout: axis 0 is the sequence and axis 1 the batch.
# The input names are kept unchanged because the benchmark cell refers to them.
torch.onnx.export(
    model.transformer,
    (x, model.attn_mask),
    "text_encoder.onnx",  # where to save the model
    opset_version=14,  # the ONNX version to export the model to
    input_names=["input_ids", "attention_mask"],  # "input_ids" actually carries token + positional embeddings
    output_names=["text_embeds"],  # the model's output names
    dynamic_axes={  # variable length axes
        "input_ids": {0: "sequence", 1: "batch"},
        # The mask is 2-D [seq, seq] with no batch axis. Reusing the "sequence" label
        # ties both of its dims to the input's sequence length.
        "attention_mask": {0: "sequence", 1: "sequence"},
        # Assumes the transformer returns the same (seq, batch, dim) layout — verify.
        "text_embeds": {0: "sequence", 1: "batch"},
    }
)

In [None]:
# Benchmark the exported image encoder at a fixed 1x3x240x240 input using the
# synchronous API (`-api sync`). Presumably OpenVINO's benchmark_app — must be on PATH.
!benchmark_app -m image_encoder.onnx -shape "image[1,3,240,240]" -api sync

In [None]:
# Benchmark the exported text transformer at fixed shapes:
# - "input_ids" is the sequence-first embedding tensor [77 tokens, batch 1, 640 dims];
# - "attention_mask" is the [77, 77] mask.
# Presumably OpenVINO's benchmark_app, using the synchronous API.
!benchmark_app -m text_encoder.onnx -shape "input_ids[77,1,640],attention_mask[77,77]" -api sync

In [22]:
# Export the full CLIP model (image and text towers) to a single ONNX graph.
# The image and text batches are independent (here 1 image vs. 3 prompts), so they
# get distinct symbolic batch names. Sharing one name would declare them equal.
torch.onnx.export(
    model,
    (image, text),
    "model.onnx",  # where to save the model
    opset_version=14,  # the ONNX version to export the model to
    input_names=["image", "text"],  # the model's input names
    # Assumes forward returns (image_features, text_features, logit_scale) — verify
    # against the installed open_clip version.
    output_names=["image_embedding", "text_embedding", "logit_scale"],
    dynamic_axes={  # variable length axes
        # Only the batch axis varies. Channels/resolution are presumably fixed by the
        # vision tower (240x240 for this model).
        "image": {0: "image_batch"},
        "text": {0: "text_batch"},
        "image_embedding": {0: "image_batch"},
        "text_embedding": {0: "text_batch"},
    }
)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
