
Huge discrepancy between HuggingFace and Timm for ViT and other vision transformers #19305

Closed
Phuoc-Hoan-Le opened this issue Oct 3, 2022 · 7 comments

@Phuoc-Hoan-Le

Phuoc-Hoan-Le commented Oct 3, 2022

Feature request

The differences between the Hugging Face and timm implementations of Vision Transformers can be listed as follows:
- Missing stochastic depth (https://arxiv.org/abs/2012.12877)
- Using m.weight.data.normal_(mean=0.0, std=0.02) instead of trunc_normal_()
- Missing trunc_normal_() init for the position embedding and cls_token

My DeiT started training properly once I switched my Hugging Face ViT model to the trunc_normal_() init and added stochastic depth. I also removed the head-pruning functionality and stopped inheriting the Hugging Face ViT model class from PreTrainedModel, but I'm not sure whether those changes contributed to the fix.
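For reference, here is a minimal sketch of the two changes described above: a timm-style stochastic-depth (drop-path) helper and trunc_normal_() init for the cls token and position embeddings. The embedding width and sequence length below are hypothetical (DeiT-Tiny-like), not taken from either library.

```python
import torch
import torch.nn as nn

def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Stochastic depth: randomly drop the residual branch per sample."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    # Scale by 1/keep_prob so the expected value is unchanged.
    return x * mask / keep_prob

# trunc_normal_() init for cls_token and position embeddings, instead of
# .normal_(mean=0.0, std=0.02). Sizes here are illustrative.
embed_dim, seq_len = 192, 197
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
nn.init.trunc_normal_(cls_token, std=0.02)
nn.init.trunc_normal_(pos_embed, std=0.02)
```

In a transformer block, drop_path would wrap each residual branch, e.g. `x = x + drop_path(attn(norm(x)), p, self.training)`.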

Motivation

These changes can mean the difference between getting NaNs or not when training DeiT with the recipe from https://arxiv.org/abs/2012.12877.

Your contribution

Would love to share my code, but I can't. I refer you to the timm implementation instead (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py).

@alaradirik
Contributor

alaradirik commented Oct 5, 2022

@CharlesLeeeee thank you for bringing this up! We are aware of the discrepancy and aim to rectify it soon. We will fix the parameter initialization issue shortly and open a separate PR to add stochastic depth.

cc @LysandreJik @NielsRogge @amyeroberts

@Phuoc-Hoan-Le
Author

I believe setting the eps in LayerNorm to 1e-6 rather than 1e-12 is also important.

@rwightman

rwightman commented Oct 14, 2022

FWIW, there is a related issue on the timm side as well: huggingface/pytorch-image-models#1477

As per my comments there, the init issue should be minor and likely inconsequential, as it would not produce a significant difference given that std == 0.02. I've trained from scratch with far more different inits and the end results weren't too far off.

The LayerNorm eps is likely a real issue, though; that was not mentioned on the timm side. In float16, 0 + 1e-12 = 0, but not so for 1e-6 or 1e-5, which are the defaults for every vision model I'm aware of that uses LN. It looks like other models, such as ConvNeXt, may also incorrectly use 1e-12. This can cause stability issues at reduced precision and will change the validation results for weights pretrained with 1e-5 or 1e-6. Generally, 1e-12 should only be used as eps if you're sticking with float32 (or all uses of that eps are guaranteed to be upcast to float32).

@Phuoc-Hoan-Le
Author

Kaiming initialization should be used for nn.Conv2d rather than .normal_() initialization in ViTPreTrainedModel, or in any class that directly inherits from PreTrainedModel. The biases of the nn.Conv2d layers in ViT should also be initialized the same way PyTorch does by default (https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d). @LysandreJik @NielsRogge @amyeroberts @alaradirik
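For concreteness, PyTorch's default nn.Conv2d initialization (the reset_parameters linked above) can be re-applied as below. The helper name and the patch-embedding sizes are illustrative; `nn.init._calculate_fan_in_and_fan_out` is the private helper PyTorch's own code uses.

```python
import math
import torch.nn as nn

def init_conv_like_pytorch(conv: nn.Conv2d) -> None:
    """Re-apply PyTorch's default nn.Conv2d initialization:
    Kaiming uniform for the weight, fan-in-scaled uniform for the bias."""
    nn.init.kaiming_uniform_(conv.weight, a=math.sqrt(5))
    if conv.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(conv.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(conv.bias, -bound, bound)

# Example: a ViT patch-embedding projection (hypothetical sizes),
# fan_in = 3 * 16 * 16 = 768.
patch_embed = nn.Conv2d(3, 192, kernel_size=16, stride=16)
init_conv_like_pytorch(patch_embed)
```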

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@NielsRogge NielsRogge added the Core: Modeling Internals of the library; Models. label Nov 29, 2022
@github-actions github-actions bot closed this as completed Dec 7, 2022
@NielsRogge
Contributor

NielsRogge commented Dec 9, 2022

@CharlesLeeeee you are partially right: ViT uses PyTorch's default initialization scheme for nn.Conv2d, at least in timm. The JAX implementation, however, uses a LeCun normal init, as seen here.
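As a sketch of the difference: LeCun normal is a truncated normal whose std is scaled by 1/sqrt(fan_in), versus PyTorch's default Kaiming uniform. The helper below is illustrative and omits the small variance correction flax applies to compensate for the truncation.

```python
import math
import torch
import torch.nn as nn

def lecun_normal_(tensor: torch.Tensor) -> torch.Tensor:
    """Approximate LeCun normal init: truncated normal with variance
    1/fan_in, truncated at +/- 2 std. (Flax additionally rescales to
    correct the variance lost to truncation; omitted here.)"""
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(tensor)
    std = math.sqrt(1.0 / fan_in)
    return nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-2 * std, b=2 * std)

# Hypothetical ViT patch projection: fan_in = 3 * 16 * 16 = 768.
conv = nn.Conv2d(3, 192, kernel_size=16, stride=16)
lecun_normal_(conv.weight)
```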

I'm working on this in #19449
