Hyper-parameters of ViT-B/16 training from scratch #2

Thanks for sharing your code. Can you provide the hyper-parameters (e.g. learning rate, weight decay, optimizer type, training epochs) of ViT-B/16 trained from scratch on the ImageNet dataset? Many thanks.
Note that for the published checkpoints we pre-trained on ImageNet-21k. As for the hyper-parameters: we used the same cropping code, but an image size of 224 (thus a 14x14 grid). The model was exactly the same, other than the additional penultimate layer with dimensionality [...].
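(To make the geometry concrete, a small sketch, plain Python with illustrative names only, of how a 224-pixel input yields the 14x14 grid for ViT-B/16:)

```python
# ViT-B/16 at resolution 224: the image is cut into non-overlapping
# 16x16 patches, each becoming one token.
image_size = 224
patch_size = 16                    # the "/16" in ViT-B/16

grid = image_size // patch_size    # 224 // 16 = 14 -> a 14x14 grid
num_patches = grid * grid          # 196 patch tokens
seq_len = num_patches + 1          # +1 for the [class] token -> 197
print(grid, num_patches, seq_len)  # 14 196 197
```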
Thanks for the reply!
By the way, what's the top-1 accuracy of ViT-B/16 trained from scratch on ImageNet with an image size of 224? There is a statement in the paper [...]. Is it 77.9%? Thanks.
The 79.9% refers to the self-supervised pre-training; see B.1.2 in the appendix for details. The B/16 model pre-trained and fine-tuned on [...]
That was a typo (now corrected); it should have said 12.4M examples. See this comment for more details.
What is the top-1 accuracy of pre-training (without fine-tuning) on ImageNet-2012?
The top-1 accuracy (evaluated on a 50k holdout from the training set) at the end of pre-training from the original ViT paper was as follows: [...]
Note that we have much more detail about pre-training from scratch in the paper How to train your ViT?...; check out the database in our Colab. For example, to show the final pre-training top-1 accuracy of a variety of models and pre-training settings:

```python
import plotly.express as px

# `df` is the runs dataframe loaded earlier in the Colab
# (one row per pre-trained checkpoint).
px.scatter(
    df.drop_duplicates('filename').query('ds=="i1k" and final_val>0.4'),
    y='final_val',     # final pre-training top-1 accuracy
    x='aug',           # data augmentation setting
    color='wd',        # weight decay
    symbol='do',       # dropout
    facet_col='name',  # one panel per model variant
    facet_col_wrap=4,
)
```
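(In the Colab, `df` is loaded earlier from the published index of all checkpoints from the paper. A minimal sketch of that step; the `gs://` path follows the repository README, so verify it against the Colab:)

```python
import pandas as pd

# Index of all pre-trained checkpoints from "How to train your ViT?";
# reading a gs:// path requires gcsfs (pre-installed in Colab).
df = pd.read_csv('gs://vit_models/augreg/index.csv')
```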
Hi, do you use the same hyper-parameters when pre-training on ImageNet-1k?
For ImageNet-1k pre-training we used the following hparams, which differ from those used for pre-training on ImageNet-21k: [...]

(Note that training L/16 and L/32 on i1k can be tricky; you might want to reduce the number of epochs or augment the data as described in the How to train your ViT? paper.)
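(As an illustration of consulting that database when choosing a run, here is a hedged sketch reusing the `df` from the snippet above; the column value `"B/16"` is assumed from the index schema. It picks the strongest ImageNet-1k pre-trained B/16 run:)

```python
# Hypothetical query: best ImageNet-1k pre-trained B/16 checkpoint,
# ranked by final pre-training validation accuracy.
best_b16 = (
    df.query('ds=="i1k" and name=="B/16"')
      .sort_values('final_val', ascending=False)
      .iloc[0]
)
print(best_b16['filename'], best_b16['final_val'])
```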