Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch clip model for ONNX compatibility #219

Merged
merged 2 commits into from
Apr 10, 2022

Conversation

chajath
Copy link
Contributor

@chajath chajath commented Feb 18, 2022

Team,

When productionizing AI models like CLIP, it is often useful to be able to export to ONNX so that we can utilize training and serving ecosystem built around ONNX runtime. This PR includes changes that are necessary to make ONNX compilation and runtime working without needing to further patch the code:

With these changes, we can now compile and run ONNX models. For example:

import clip
import torch

m, pre = clip.load("RN50")
npx = m.visual.input_resolution
dummy_image = torch.randn(10, 3, npx, npx)
dummy_texts = clip.tokenize(["quick brown fox", "lorem ipsum"])
m.forward(dummy_image ,dummy_texts) # Original CLIP result (1)

torch.onnx.export(m, (dummy_image, dummy_texts), "clip_resnet.onnx", export_params=True,
  input_names=["IMAGE", "TEXT"],
  output_names=["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"],
  opset_version=14,
  dynamic_axes={
      "IMAGE": {
          0: "image_batch_size",
      },
      "TEXT": {
          0: "text_batch_size",
      },
      "LOGITS_PER_IMAGE": {
          0: "image_batch_size",
          1: "text_batch_size",
      },
      "LOGITS_PER_TEXT": {
          0: "text_batch_size",
          1: "image_batch_size",
      },
  }
)

# Now run onnxruntime to verify
import onnxruntime as ort

ort_sess = ort.InferenceSession("clip_resnet.onnx")
result=ort_sess.run(["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"], 
  {"IMAGE": dummy_image.numpy(), "TEXT": dummy_texts.numpy()})
result # verify that result is comparable to (1)

I've locally verified the result. I see that CLIP has a github action to run pytest. I will be happy to further contribute by adding onnx-specific tests to see if the model can be compiled and the resulting model is correct.

I understand that ONNX compatibility might not be the primary goal of CLIP repo, but maintaining compatibility with ONNX will be immensely useful to the ML practitioners out in the wild, so please take a look at the change.

Thanks!

Changes to use INT32 for tokenization, since ONNX doesn't yet support ArgMax(INT64)
Use explicit dimension for norm
@chajath
Copy link
Contributor Author

chajath commented Feb 18, 2022

Regarding the test error, I see that torch 1.7.1 requires indices to be LongTensor, whereas 1.8 and above allows both LongTensor and IntTensor. Do you have suggestions as to how we should workaround? I wonder how long 1.7.1 is going to be supported by CLIP.

@chajath
Copy link
Contributor Author

chajath commented Feb 18, 2022

So I've pushed the change to preserve the behavior when torch version is < 1.8.0. This makes all the tests to pass again in 1.7.1, but 1.7.1 users won't benefit from this fix.

@hanxiao
Copy link

hanxiao commented Apr 10, 2022

For ONNX support, please use https://github.com/jina-ai/clip-as-service/

@chajath
Copy link
Contributor Author

chajath commented Apr 10, 2022

For ONNX support, please use https://github.com/jina-ai/clip-as-service/

Respectfully, I don't think it's appropriate to promote your service in this PR.

@jongwook
Copy link
Collaborator

Thanks for the PR! I think branching based on torch version makes sense.

@jongwook jongwook merged commit 7ef63f2 into openai:main Apr 10, 2022
@chajath chajath deleted the chajath/onnx-compatibility-patches branch April 10, 2022 20:44
@deshwalmahesh
Copy link

Team,

When productionizing AI models like CLIP, it is often useful to be able to export to ONNX so that we can utilize training and serving ecosystem built around ONNX runtime. This PR includes changes that are necessary to make ONNX compilation and runtime working without needing to further patch the code:

With these changes, we can now compile and run ONNX models. For example:

import clip
import torch

m, pre = clip.load("RN50")
npx = m.visual.input_resolution
dummy_image = torch.randn(10, 3, npx, npx)
dummy_texts = clip.tokenize(["quick brown fox", "lorem ipsum"])
m.forward(dummy_image ,dummy_texts) # Original CLIP result (1)

torch.onnx.export(m, (dummy_image, dummy_texts), "clip_resnet.onnx", export_params=True,
  input_names=["IMAGE", "TEXT"],
  output_names=["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"],
  opset_version=14,
  dynamic_axes={
      "IMAGE": {
          0: "image_batch_size",
      },
      "TEXT": {
          0: "text_batch_size",
      },
      "LOGITS_PER_IMAGE": {
          0: "image_batch_size",
          1: "text_batch_size",
      },
      "LOGITS_PER_TEXT": {
          0: "text_batch_size",
          1: "image_batch_size",
      },
  }
)

# Now run onnxruntime to verify
import onnxruntime as ort

ort_sess = ort.InferenceSession("clip_resnet.onnx")
result=ort_sess.run(["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"], 
  {"IMAGE": dummy_image.numpy(), "TEXT": dummy_texts.numpy()})
result # verify that result is comparable to (1)

I've locally verified the result. I see that CLIP has a github action to run pytest. I will be happy to further contribute by adding onnx-specific tests to see if the model can be compiled and the resulting model is correct.

I understand that ONNX compatibility might not be the primary goal of CLIP repo, but maintaining compatibility with ONNX will be immensely useful to the ML practitioners out in the wild, so please take a look at the change.

Thanks!

How can you use model.encode_image() instead of model.forward() for the onnx model?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants