
Hyper-parameters of ViT-B/16 training from scratch #2

Closed · liuyuyuil opened this issue Oct 26, 2020 · 9 comments

@liuyuyuil

Thanks for sharing your code. Can you provide the hyper-parameters (e.g. learning rate, weight decay, optimizer type, training epochs) of ViT-B/16 training from scratch on ImageNet dataset? Many thanks.

@andsteing
Collaborator

andsteing commented Oct 26, 2020

Note that for the published checkpoints we pretrained on imagenet21k (see README), using ~~102.4M~~ 12.4M examples for training.

As for the hyperparameters:

batch_size=4096
lr.base=1e-3
lr.decay_type=linear
lr.linear_end=1e-5
lr.warmup_steps=10_000
dropout_rate=0.1
num_epochs=90
weight_decay=0.03
optimizer=Adam
representation_size=768

We used the same cropping code but an image size of 224 (thus a 14x14 patch grid).

The model was exactly the same, apart from the additional penultimate layer with dimensionality representation_size. The final classification layer's bias weights were initialized to -10.
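
For illustration, a minimal optax sketch of the schedule/optimizer described above (this is not the actual vit_jax training code: the step count is derived from 90 epochs over 12.4M examples at batch 4096, and adamw's decoupled weight decay is an assumption about how weight_decay was applied):

import optax

# Linear warmup to lr.base, then linear decay to lr.linear_end (sketch only).
base_lr, end_lr, warmup_steps = 1e-3, 1e-5, 10_000
total_steps = 90 * 12_400_000 // 4096  # ~272k steps: 90 epochs, 12.4M examples, batch 4096

schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup_steps),                   # warmup
        optax.linear_schedule(base_lr, end_lr, total_steps - warmup_steps),  # linear decay
    ],
    boundaries=[warmup_steps],
)

# Adam with weight_decay=0.03 (decoupled decay via adamw is an assumption here).
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.03)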

@liuyuyuil
Author

Thanks for the reply!

@liuyuyuil
Author

liuyuyuil commented Oct 26, 2020

By the way, what's the top-1 accuracy of ViT-B/16 trained from scratch on ImageNet with an image size of 224? There is a statement in the paper:

> With self-supervised pre-training, our smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant improvement of 2% to training from scratch.

Is it 77.9%? Thanks.

@andsteing
Collaborator

The 79.9% refers to the self-supervised pretraining - see B.1.2. in the appendix for details. The B/16 model pre-trained and fine-tuned on imagenet2012 achieves 77.9% (see table 5 in the appendix).

@andsteing
Collaborator

That was a typo (now corrected) - it should have said 12.4M examples. See this comment for more details.

@cissoidx

> The 79.9% refers to the self-supervised pretraining - see B.1.2. in the appendix for details. The B/16 model pre-trained and fine-tuned on imagenet2012 achieves 77.9% (see table 5 in the appendix).

What is the top-1 accuracy of pre-training (without fine-tuning) on imagenet2012?

@andsteing
Collaborator

andsteing commented Aug 18, 2021

Top-1 accuracy (evaluated on a 50k holdout from the training set) at the end of pre-training for the original ViT paper was as follows:

| name | val_acc |
| --- | --- |
| ViT-B/32 i1k | 69.19% |
| ViT-B/16 i1k | 74.79% |
| ViT-L/32 i1k | 66.90% |
| ViT-L/16 i1k | 72.59% |

Note that we have much more detail about pre-training from scratch in the paper How to train your ViT?...; check out the database in our Colab:

https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb

For example, to show you the final pre-training top1 accuracy of a variety of models and pre-training settings:

import plotly.express as px

# `df` is the results DataFrame loaded earlier in the Colab linked above.
# Final pre-training validation accuracy vs. augmentation setting, one facet per model.
px.scatter(
    df.drop_duplicates('filename').query('ds=="i1k" and final_val>0.4'),
    y='final_val',   # final pre-training top-1 validation accuracy
    x='aug',         # data augmentation setting
    color='wd',      # weight decay
    symbol='do',     # dropout rate
    facet_col='name',
    facet_col_wrap=4,
)
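
(If you want to run the snippet outside the Colab, `df` is the checkpoint index; a minimal sketch, assuming the index CSV is still published at gs://vit_models/augreg/index.csv as in the Colab:)

import pandas as pd
import tensorflow as tf

# Assumed location of the AugReg checkpoint index (taken from the Colab above);
# adjust the path if the bucket layout differs.
with tf.io.gfile.GFile('gs://vit_models/augreg/index.csv') as f:
    df = pd.read_csv(f)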

@cissoidx

> Note that for the published checkpoints we pretrained on imagenet21k (see README), using ~~102.4M~~ 12.4M examples for training.
>
> As for the hyperparameters:
>
> batch_size=4096
> lr.base=1e-3
> lr.decay_type=linear
> lr.linear_end=1e-5
> lr.warmup_steps=10_000
> dropout_rate=0.1
> num_epochs=90
> weight_decay=0.03
> optimizer=Adam
> representation_size=768
>
> We used the same cropping code but an image size of 224 (thus a 14x14 patch grid).
>
> The model was exactly the same, apart from the additional penultimate layer with dimensionality representation_size. The final classification layer's bias weights were initialized to -10.

Hi, do you use the same hyperparameters when pre-training on imagenet1k?

@andsteing
Collaborator

For Imagenet1k pre-training we used the following hparams, which differ from those used for pre-training on Imagenet21k:

grad_clip_norm=1.0
lr.base=3e-3
lr.decay_type=cosine
dropout_rate=0.1  # B/16, B/32
dropout_rate=0.2  # L/16, L/32
num_epochs=300  # B/16, B/32
weight_decay=0.3

(Note that training L/16 and L/32 on i1k can be tricky; you might want to reduce the number of epochs or augment the data as described in the How to train your ViT? paper.)
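
Again only an illustrative optax sketch (not the repo's code), assuming the warmup and batch size are unchanged from the i21k settings above and 300 epochs over the ~1.28M ImageNet-1k training images:

import optax

# Cosine decay with linear warmup, plus global-norm gradient clipping (sketch only).
base_lr, warmup_steps = 3e-3, 10_000   # warmup assumed same as for i21k above
total_steps = 300 * 1_281_167 // 4096  # ~94k steps: 300 epochs (B/16, B/32) at batch 4096

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=base_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                         # grad_clip_norm=1.0
    optax.adamw(learning_rate=schedule, weight_decay=0.3),  # weight_decay=0.3 (decoupled here)
)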
