Models for different embedding dimensions #3
Hi Ethan 👋 Unfortunately, we do not have these model weights anymore; I apologize for the inconvenience! However, training models with different embedding sizes (as mentioned in the paper) is much faster and more computationally efficient than training models from scratch (see Section 4.4).

We have released ViT-L/16 weights (links in README) that remain unchanged in the models you want. Reproducing these experiments will require you to re-initialize the image/text projection layers (two layers, `visual_proj` and `textual_proj` in the code further down this thread). Based on my observation, these models train very quickly: you will get reasonable performance well before 30K iterations since the number of trainable parameters is small. Moreover, you can afford a large batch size with fewer GPUs due to the reduced trainable model size. Let me know if you have further questions!
Correction: these will be
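As a concrete illustration of the suggestion above (not code from this thread), here is a minimal sketch of re-initializing the two projection layers for a new embedding size. It assumes `model` is a MERU model loaded as in the comment below, that `visual_proj` and `textual_proj` are bias-free `nn.Linear` layers (the attribute names are taken from the parameter list later in this thread), and that `new_embed_dim` is a placeholder for whatever dimension you want to train:

```python
import torch.nn as nn

new_embed_dim = 256  # assumed target embedding size; pick whatever dimension you need

# Re-create the image/text projections with the new output dimension while keeping
# the encoder widths (in_features) unchanged. These layers start from random weights
# and are the only parts trained from scratch.
# (Move them to the model's device afterwards if you built the model on GPU.)
model.visual_proj = nn.Linear(model.visual_proj.in_features, new_embed_dim, bias=False)
model.textual_proj = nn.Linear(model.textual_proj.in_features, new_embed_dim, bias=False)
```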
I am experimenting with models of different dimensions and used the following code. It assumes you have already downloaded one of the pretrained models and have a path to a training config file (the `train_config` and `checkpoint_path` variables below point to these).

First, let's load the model:

```python
import torch
from meru.config import LazyConfig, LazyFactory
from meru.utils.checkpointing import CheckpointManager
# get device
device = (
    torch.cuda.current_device()
    if torch.cuda.is_available()
    else torch.device("cpu")
)
# Create the model using training config and load pre-trained weights.
_C_TRAIN = LazyConfig.load(train_config)
model = LazyFactory.build_model(_C_TRAIN, device).eval()
CheckpointManager(model=model).load(checkpoint_path)
```

Now freeze all layers except for those in `learnable_params`:

```python
learnable_params = ['logit_scale', 'curv', 'visual_alpha', 'textual_alpha', 'visual_proj.weight', 'textual_proj.weight']
for name, p in model.named_parameters():
    if name not in learnable_params:
        p.requires_grad = False
```

After this, start your training! Hope this helps!
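Not from the thread, but as a follow-up sketch: once the loop above has frozen everything else, you can hand only the remaining trainable parameters to an optimizer. The optimizer class and learning rate here are placeholders, not the settings used in the paper:

```python
# Keep only the parameters the loop above left trainable
# (logit_scale, curv, the alphas, and the re-initialized projections).
trainable = [p for p in model.parameters() if p.requires_grad]
print(f"Trainable parameter tensors: {len(trainable)}")

# Placeholder optimizer and learning rate; use whatever your training config specifies.
optimizer = torch.optim.AdamW(trainable, lr=5e-4)
```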
Hi,
Wondering if you could upload trained models for different embedding sizes?
Thanks