
[CvT] Tensorflow implementation #18597

Merged
merged 15 commits into from
Oct 11, 2022

Conversation

mathieujouffroy
Contributor

@mathieujouffroy mathieujouffroy commented Aug 12, 2022

What does this PR do?

This PR adds the Cvt model implementation in Tensorflow.
This includes the base model and the model with an image classification head on top.

TODO

  • Write the fundamental components (Convolutional Token Embeddings & Convolutional Transformer Block)
  • Write base model & image classification model
  • Modify related utilities
  • Write relevant tests (in test suite)
  • Preview Tensorflow documentation for Cvt

Before submitting

Questions

  • In the configuration file of the CvT model, layer_norm_eps is initialized to 1e-12.
    However, the original implementation appears to use epsilon=1e-5.
    Moreover, the PyTorch CvT model (Hugging Face) does not seem to use layer_norm_eps=1e-12 for layer normalization anywhere in the model, using the default epsilon=1e-5 instead.
    What is layer_norm_eps in the CvT configuration file actually used for?
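For reference, the practical effect of the epsilon choice can be illustrated with a minimal NumPy sketch (toy code, not the transformers implementation; `layer_norm` here is a hypothetical stand-in for `nn.LayerNorm` / `tf.keras.layers.LayerNormalization`):

```python
import numpy as np

def layer_norm(x, eps):
    # Normalize over the last axis, as both frameworks' LayerNorm layers do.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.array([[1.0, 2.0, 3.0]])
diff = np.max(np.abs(layer_norm(x, 1e-12) - layer_norm(x, 1e-5)))
print(diff)  # tiny: epsilon only matters when the variance itself is tiny
```

For typical activations the two settings are nearly indistinguishable, which may be why the config value and the value actually used could diverge without anyone noticing.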

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 12, 2022

The documentation is not available anymore as the PR was closed or merged.

@mathieujouffroy mathieujouffroy changed the title Tf cvt [CvT] Tensorflow implementation Aug 12, 2022
@mathieujouffroy mathieujouffroy marked this pull request as draft August 16, 2022 15:47
@mathieujouffroy mathieujouffroy marked this pull request as ready for review August 30, 2022 19:43
@LysandreJik
Member

Thanks for your PR @mathieujouffroy! Let me ping @amyeroberts for review :)

@mathieujouffroy
Contributor Author

You're welcome. Cool, thanks! Should I create an Issue?

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for adding the model and for such a nice PR ❤️ Just a few nits and one comment regarding a permutation.

We recently had the first community contributor open a PR to add weights on the hub (see PR comment here on steps). Once we've had two other 👍 we should be good to add the model weights, run the slow tests, and then merge.

7 review threads on src/transformers/models/cvt/modeling_tf_cvt.py (all outdated and resolved)
@amyeroberts amyeroberts requested a review from gante August 31, 2022 12:46
Member

@gante gante left a comment


Thank you for the addition of CvT 🔥

I've added a few comments in the code, the most concerning one being the comment about the output shapes.

After the issues raised by me and @amyeroberts are addressed, we can finalize the review process by:

  1. tagging a core maintainer for a third review;
  2. opening PRs for the TF weights (I will share the instructions on how to do it :) )

Review threads on src/transformers/modeling_tf_pytorch_utils.py and src/transformers/models/cvt/modeling_tf_cvt.py (all outdated and resolved)
@mathieujouffroy
Contributor Author

Thanks a lot for both of your reviews 🙏 !
I've corrected the issues :)
However, I kept using shape_list instead of tf.shape throughout the model implementation, as tf.shape was breaking things while running the tests (see comment above).
Should I follow the instructions in this PR comment to upload the weights?

@gante
Member

gante commented Sep 2, 2022

@mathieujouffroy awesome, seems like we are ready to move on to the next stage. I'm adding @sgugger as the last reviewer.

Meanwhile, you can open the PR to the TF model weights on the hub as follows:

  1. Make sure you have the latest version of the hub installed (pip install huggingface_hub -U) and that you are logged in to HF with a write token (huggingface-cli login)
  2. Run transformers-cli pt-to-tf --model-name foo/bar from this branch :D
  3. In the Hub PR, tag @joaogante, @nielsr, @sgugger

@gante gante requested a review from sgugger September 2, 2022 11:45
Collaborator

@sgugger sgugger left a comment


Thanks a lot for your PR! Good to merge once the weights have been added!

@mathieujouffroy
Contributor Author

@mathieujouffroy awesome, seems like we are ready to move on to the next stage. I'm adding @sgugger as the last reviewer.

Meanwhile, you can open the PR to the TF model weights on the hub as follows:

  1. Make sure you have the latest version of the hub installed (pip install huggingface_hub -U) and that you are logged in to HF with a write token (huggingface-cli login)
  2. Run transformers-cli pt-to-tf --model-name foo/bar from this branch :D
  3. In the Hub PR, tag @joaogante, @nielsr, @sgugger

I am getting an error when using transformers-cli pt-to-tf --model-name microsoft/cvt-13:

File "/Users/MathieuJouffroy/transformers/src/transformers/commands/pt_to_tf.py", line 307, in run
    + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error])
ValueError: The cross-loaded TensorFlow model has different outputs, something went wrong!

List of maximum output differences above the threshold (5e-05):
logits: 1.190e-04

List of maximum hidden layer differences above the threshold (5e-05):
hidden_states[2]: 1.227e-02

It seems that both max_crossload_output_diff and max_crossload_hidden_diff are larger than self._max_error (5e-5):
respectively, I have max_crossload_output_diff = 0.00011897087 (1.190e-04) and max_crossload_hidden_diff = 0.012268066 (1.227e-02).

I am trying to figure out how to correct this error (WIP).

@gante
Member

gante commented Sep 2, 2022

@mathieujouffroy ~1e-2 is quite large -- does this happen exclusively on microsoft/cvt-13, or across all CvT models?

@mathieujouffroy
Contributor Author

@mathieujouffroy ~1e-2 is quite large -- does this happen exclusively on microsoft/cvt-13, or across all CvT models?

Yes, unfortunately it happens across all CvT models.
When inspecting the difference between the hidden states of the PyTorch model and those of the TensorFlow model, I can see that the difference increases throughout the model (with the number of layers).
The CvT model is composed of three encoder stages, with 1, 2 and 10 layers respectively. In the last stage, the difference between the two models' hidden states grows from ~1e-5 at layer[0] to ~1e-2 at layer[9].
For microsoft/cvt-13:

CvtLayer output vs TFCvtLayer output

diff pt-tf stage[0]/layer[0]: 1.77919864654541e-05

diff pt-tf stage[1]/layer[0]: 2.4199485778808594e-05
diff pt-tf stage[1]/layer[1]: 3.9249658584594727e-05

diff pt-tf stage[2]/layer[0]: 3.0934810638427734e-05
diff pt-tf stage[2]/layer[1]: 0.000102996826171875
diff pt-tf stage[2]/layer[2]: 0.0004825592041015625
diff pt-tf stage[2]/layer[3]: 0.0009307861328125
diff pt-tf stage[2]/layer[4]: 0.001621246337890625
diff pt-tf stage[2]/layer[5]: 0.0032196044921875
diff pt-tf stage[2]/layer[6]: 0.0064239501953125
diff pt-tf stage[2]/layer[7]: 0.0091705322265625
diff pt-tf stage[2]/layer[8]: 0.012481689453125
diff pt-tf stage[2]/layer[9]: 0.01226806640625

Hidden Differences:
hidden_states[0]:1.77919864654541e-05
hidden_states[1]:3.9249658584594727e-05
hidden_states[2]:0.01226806640625

Output Differences:
logits:0.00011897087097167969

I can't seem to correct this issue. I was wondering if this was due to floating-point operations.
Do you have any advice? 🙏
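The per-layer comparison above can be sketched framework-agnostically. A minimal NumPy version of the diff computation, with toy stand-ins for the real PT/TF hidden states (hypothetical data, not the actual CvT tensors):

```python
import numpy as np

def max_abs_diffs(pt_hidden_states, tf_hidden_states):
    # Per-layer max absolute difference between two sequences of arrays.
    return [float(np.max(np.abs(pt - tf)))
            for pt, tf in zip(pt_hidden_states, tf_hidden_states)]

# Toy stand-ins: same shapes, with a small perturbation added to one side.
rng = np.random.default_rng(0)
pt = [rng.standard_normal((1, 4, 8)) for _ in range(3)]
tf = [h + 1e-5 * rng.standard_normal(h.shape) for h in pt]
diffs = max_abs_diffs(pt, tf)
print(diffs)  # three small values, all well below 1e-3
```

A steadily growing diff across layers, as in the listing above, usually points at a small per-layer discrepancy being amplified (e.g. by residual connections) rather than at a single broken layer.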

@gante
Member

gante commented Sep 9, 2022

@mattchurgin in these cases, a deep dive has to be done -- place a pair of breakpoint() in the layer where the problems start, one in each framework, and see which operation causes the divergence. Then, confirm that the TF operation/layer is parametrized correctly and, if it is, one has to dig even deeper :D

@mathieujouffroy
Contributor Author

mathieujouffroy commented Sep 26, 2022

Hello @gante, sorry for the late response.
I've done a deep dive into both frameworks. It seems that the Batch Normalization is responsible for the divergence. The 2 residual connections further increase the divergence throughout the model. However, I have parameterized tf.keras.layers.BatchNormalization accordingly to the default parameters of pytorch (epsilon=1e-5 and momentum=0.1). I have also set both models in inference mode when testing.

Is this divergence due to the momentum definition of Batch Normalization being different in tensorflow than in pytorch ?

When removing the Batch Normalization layers from both frameworks, the difference in the output tensors and the hidden states is greatly reduced. I get a max_crossload_output_diff of ~e-6 and a max_crossload_hidden_diff of ~e-4 for all Cvt models. However, the max_crossload_hidden_diff is still higher than 5e-5 (I have ~e-4). The 2 residual connections are responsible for this difference.

I'm a bit confused. Therefore I've inspected the ViT model (google/vit-base-patch16-224) which also has 2 residual connections. There is also a divergence in the hidden states between the tensorflow implementation and the pytorch implemention. This difference also increases throughout the layers (with the residual connections), until it reaches a max_crossload_hidden_diff of ~2e-2 at layer 12.

Is this behaviour normal/acceptable ?
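One known PT/TF gotcha worth ruling out here (the commit log later notes a "corrected momentum value"): the two frameworks define batch-norm momentum in mirrored ways, so PyTorch's default momentum=0.1 corresponds to momentum=0.9 in tf.keras.layers.BatchNormalization, not 0.1. A minimal sketch of the two update conventions (plain Python, not library code):

```python
def torch_style_update(running, batch_stat, momentum=0.1):
    # PyTorch BatchNorm: momentum weights the NEW batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat

def keras_style_update(moving, batch_stat, momentum=0.9):
    # Keras BatchNormalization: momentum weights the OLD moving statistic.
    return momentum * moving + (1.0 - momentum) * batch_stat

# With mirrored values (0.1 in PT, 0.9 in TF) the two updates coincide:
print(torch_style_update(1.0, 0.0, momentum=0.1))  # 0.9
print(keras_style_update(1.0, 0.0, momentum=0.9))  # 0.9
```

Note this only affects how the moving statistics are updated during training; with cross-loaded weights and both models in inference mode, the moving stats are fixed, so a mismatched momentum alone would not explain an inference-time divergence.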

@gante
Member

gante commented Sep 28, 2022

@mathieujouffroy That's a great in-depth exploration!

Previously we didn't have these checks in place, so it is possible that issues like the one you're seeing slipped through the cracks. It's not positive at all to have such large mismatches (it implies that TF users will have a poorer experience). I've had in my plans to go back and double-check the most popular models with the recently introduced checks, and you've just raised the priority of the task with your message :)

I think @amyeroberts has seen similar TF/PT mismatches due to the Batch Normalization layer. @amyeroberts do you mind pitching in?

@amyeroberts
Collaborator

@mathieujouffroy Thanks for all the work digging into this 🕵️

As momentum is set for both the PyTorch and TF models, I believe their behaviour (outputs and moving-stats updates) should be the same during both inference and training, given the same weights and params.

@gante @mathieujouffroy Yes, I had similar issues with the TF ResNet port (a weights PR for reference). Like this model, the batch norm layer introduced differences which then got amplified through the forward pass. @ydshieh did some excellent detective work and found that matching all of the parameters and inputs to produce equivalent TF and PyTorch layers would still yield outputs differing on the order of 1e-7 (enough to start causing problems 😭).

Ultimately, we decided to add the weights, as the difference between the logits was small (~1e-5). I think the ~1e-4 absolute differences in this case are acceptable for adding the weights. @sgugger Is this OK?

@sgugger
Collaborator

sgugger commented Sep 28, 2022

Yes, as long as it stays in the range of 1e-4, we can accept the difference between frameworks.

@gante
Member

gante commented Sep 28, 2022

Thank you for pitching in @amyeroberts :D

@mathieujouffroy feel free to use --max-error 1e-4 (or slightly higher) in the pt-to-tf CLI to ignore those errors and push the weights!

@mathieujouffroy
Contributor Author

Hi @amyeroberts, thanks for the mention!! I've added the PR regarding the PyTorch model.

@gante following your recommendation I've added the weights on the hub 😊
As @amyeroberts had pointed out, I'll need to remove from_pt in the testing file once the weights are added.

@gante
Member

gante commented Oct 11, 2022

@mathieujouffroy weights merged 🙌

@mathieujouffroy
Contributor Author

mathieujouffroy commented Oct 11, 2022

@mathieujouffroy weights merged 🙌

Cool, thanks @gante 😊!
I'll update the testing file & run the slow tests locally.

@gante
Member

gante commented Oct 11, 2022

@mathieujouffroy off-topic: are you working with transformers as part of École 42? I've been at the school once (like 5+ years ago) and I had a friend who participated -- I really liked the concept!

@mathieujouffroy
Contributor Author

mathieujouffroy commented Oct 11, 2022

@mathieujouffroy off-topic: are you working with transformers as part of École 42? I've been at the school once (like 5+ years ago) and I had a friend who participated -- I really liked the concept!

@gante Yes, I was working with transformers 🤗 on my last project (computer vision) at École 42. The project was in partnership with Hectar, an agricultural campus. I was pretty excited to try out the vision transformers 😊.
I've also used transformers at 42 for my NLP projects and in my internship (in NLP).
I think 42 provides very good training (I've just finished) 🚀: project-based & peer-to-peer pedagogy!

@gante
Member

gante commented Oct 11, 2022

All seems ready, merging as soon as CI turns green.

@mathieujouffroy on behalf of TF users, thank you for making the ecosystem richer 🧡

@gante gante merged commit 5ca131f into huggingface:main Oct 11, 2022
@mathieujouffroy
Contributor Author

@gante @amyeroberts thanks a lot for your help and feedback!! 💛
It was really interesting and fun to do this PR (my first in an open-source project) and to get it merged 😊

@mathieujouffroy mathieujouffroy deleted the tf-cvt branch October 11, 2022 17:45
ajsanjoaquin pushed a commit to ajsanjoaquin/transformers that referenced this pull request Oct 12, 2022
* implemented TFCvtModel and TFCvtForImageClassification and modified relevant files, added an exception in convert_tf_weight_name_to_pt_weight_name, added quick testing file to compare with pytorch model

* added docstring + testing file in transformers testing suite

* added test in testing file, modified docs to pass repo-consistency, passed formatting test

* refactoring + passing all test

* small refacto, removing unwanted comments

* improved testing config

* corrected import error

* modified acces to pretrained model archive list, to pass tf_test

* corrected import structure in init files

* modified testing for keras_fit with cpu

* correcting PR issues + Refactoring

* Refactoring : improving readability and reducing the number of permutations

* corrected momentum value + cls_token initialization

* removed from_pt as weights were added to the hub

* Update tests/models/cvt/test_modeling_tf_cvt.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 18, 2022
kashif pushed a commit to kashif/transformers that referenced this pull request Oct 21, 2022