Errors while importing FlaxHybridCLIP checkpoints to FlaxCLIPModel or CLIPModel #14417
Comments
Hey @g8a9! It's not possible to load a FlaxHybridCLIP checkpoint into FlaxCLIPModel or CLIPModel directly. The hybrid CLIP model will soon be officially supported in transformers.
Nice to know, thanks! (Actually, my final goal is to have access to our two fine-tuned encoders, ViT and BERT, in PyTorch.)
The module structure is pretty much the same, so yes! If not, I'll share a script to convert the old hybrid CLIP weights to this new class.
Hey @g8a9, the clip-italian (or any hybrid CLIP) model can now be loaded using the new VisionTextDualEncoderModel class:

```python
from transformers import FlaxVisionTextDualEncoderModel, VisionTextDualEncoderModel

# `logit_scale` can be initialized using the `config.logit_scale_init_value` attribute
model = FlaxVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", logit_scale_init_value=1)
model.save_pretrained("clip-italian")

model_pt = VisionTextDualEncoderModel.from_pretrained("clip-italian", from_flax=True)
```

Let me know if this works for you and if you see any discrepancies in the result. I would like to use clip-italian to feature this new model class :)
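A minimal sketch of such a discrepancy check, assuming 224x224 images and an arbitrary dummy token sequence (the inputs below are made up, not from the issue):

```python
import numpy as np
import torch
from transformers import FlaxVisionTextDualEncoderModel, VisionTextDualEncoderModel

# Assumes the conversion above has already been run and saved to "clip-italian"
model_flax = FlaxVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", logit_scale_init_value=1)
model_pt = VisionTextDualEncoderModel.from_pretrained("clip-italian", from_flax=True)

# Dummy batch: 2 token sequences of length 16 and 2 RGB images (224x224 is an assumption)
input_ids = np.ones((2, 16), dtype="i4")
pixel_values = np.random.randn(2, 3, 224, 224).astype("float32")

flax_out = model_flax(input_ids=input_ids, pixel_values=pixel_values)
with torch.no_grad():
    pt_out = model_pt(
        input_ids=torch.tensor(input_ids, dtype=torch.long),
        pixel_values=torch.tensor(pixel_values),
    )

# The similarity logits should agree up to small numerical differences
diff = np.abs(np.array(flax_out.logits_per_image) - pt_out.logits_per_image.numpy()).max()
print(f"max abs difference: {diff}")
```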
Do you think you could push the PT checkpoint and the processor (tokenizer/feature-extractor) for clip-italian?
This worked for me. Thanks for the solution.
Environment info
transformers version: 4.12.2

Who can help
@patil-suraj @patrickvonplaten
Information
During the last Flax/JAX Community Week we trained a fine-tuned version of CLIP for the Italian language. We used the provided script, so we trained a FlaxHybridCLIP model with OpenAI's ViT and "dbmdz/bert-base-italian-xxl-uncased" BERT as encoders.

Now, I'm trying to use that model with transformers' official API classes, either FlaxCLIPModel or CLIPModel (my final goal would be to port it to PyTorch and publish it to the hub). However, I am having a hard time loading our weights into either of the two.
I tried different workarounds (see below), but none of them seems to work.
To reproduce
I assume these imports:
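The original imports block did not survive in this copy of the issue; a plausible reconstruction, assuming the FlaxHybridCLIP class comes from the community-week example file modeling_hybrid_clip.py placed on the path (the local module name is an assumption):

```python
# Assumed imports; FlaxHybridCLIP comes from the community-week example script
# (modeling_hybrid_clip.py), not from the transformers package itself.
from transformers import CLIPConfig, CLIPModel, FlaxCLIPModel
from modeling_hybrid_clip import FlaxHybridCLIP
```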
Steps to reproduce the behavior:
I first tried to load our checkpoint directly into FlaxCLIPModel and CLIPModel (see the sketch below), but for both of them I got inconsistent shapes for the text_projection dense layer: it is expected to be (512, 512), but BERT has hidden size 768, so in our checkpoint it is (768, 512).
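A minimal sketch of those two attempts, assuming the checkpoint is hosted as clip-italian/clip-italian:

```python
from transformers import CLIPModel, FlaxCLIPModel

# Attempt 1: load the Flax checkpoint directly into the official CLIP classes.
# Both calls complain about the shape of text_projection: (512, 512) is expected,
# but the checkpoint contains (768, 512) because the BERT encoder has hidden size 768.
model_flax = FlaxCLIPModel.from_pretrained("clip-italian/clip-italian")
model_pt = CLIPModel.from_pretrained("clip-italian/clip-italian", from_flax=True)
```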
If I try to ignore the mismatched shapes, loading seems to work, but I think that many of the weights from the checkpoint are not used.
I also tried to build a config where config.text_config.hidden_size == 768 and let the shapes match at loading time. In this case, I don't have mismatching sizes, but many weights from our checkpoint are still not used.
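A minimal sketch of this workaround, assuming the hidden size is patched on a default CLIPConfig before loading:

```python
from transformers import CLIPConfig, FlaxCLIPModel

# Widen the text tower so that text_projection gets shape (768, 512)
config = CLIPConfig()
config.text_config.hidden_size = 768

# No shape mismatch now, but the BERT weights in the checkpoint still do not map
# onto CLIP's own text encoder, so many checkpoint weights remain unused.
model = FlaxCLIPModel.from_pretrained("clip-italian/clip-italian", config=config)
```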
Expected behavior
This code to run flawlessly:
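The expected snippet itself is missing from this copy; presumably it was along these lines (a reconstruction with a made-up output directory, not the author's exact code):

```python
from transformers import CLIPModel, FlaxCLIPModel

# Load the fine-tuned hybrid checkpoint with the official classes...
model_flax = FlaxCLIPModel.from_pretrained("clip-italian/clip-italian")

# ...and port it to PyTorch so it can be published on the hub.
model_pt = CLIPModel.from_pretrained("clip-italian/clip-italian", from_flax=True)
model_pt.save_pretrained("clip-italian-pt")  # hypothetical output directory
```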
Thank you in advance!