LAMB: Differences from the paper author's official implementation #107

Closed
binmakeswell opened this issue Nov 24, 2020 · 5 comments · Fixed by #162

@binmakeswell

The LAMB implementation in the PyTorch version you released differs from the official TensorFlow version released by the paper's authors. In the official implementation, the authors skip some parameters, selected by name, during the update. In your implementation, however, all parameters appear to take part in the update directly.
For example, exclude_from_weight_decay=["batch_normalization", "LayerNorm", "layer_norm"]
Their implementation:
https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/lamb.py
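For context, here is a minimal sketch (not the TF Addons source) of the name-based filtering referred to above: parameters whose names match an excluded pattern are still updated, but weight decay is simply not applied to them. The parameter names used below are illustrative.

```python
import re

# Minimal sketch of name-based weight-decay exclusion (illustrative, not the TF Addons code).
exclude_from_weight_decay = ["batch_normalization", "LayerNorm", "layer_norm"]

def use_weight_decay(param_name, patterns=exclude_from_weight_decay):
    """Return True if weight decay should be applied to this parameter."""
    return not any(re.search(pattern, param_name) for pattern in patterns)

# Example with made-up parameter names
for name in ("bert/encoder/LayerNorm/gamma", "bert/encoder/attention/query/kernel"):
    print(name, "-> apply weight decay:", use_weight_decay(name))
```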

@frgfm added the question (Further information is requested) and module: optim labels on Nov 24, 2020
@frgfm
Owner

frgfm commented Nov 24, 2020

Hi @binmakeswell,

Thanks for the issue!

Actually, if I remember correctly, I committed my implementation before there was any official implementation. I usually implement the paper directly rather than reproducing other implementations. In this case, I will check your reference and see how it benefits training, performance-wise!

If you have a modification suggestion, feel free to open a PR and I'll look into it!

@frgfm self-assigned this on Nov 24, 2020
@binmakeswell
Author

Thanks for your reply. By the way, the LARS implementation seems to have a similar problem: according to the author's implementation and the official TensorFlow one, they also skip some parameters, selected by name, during the update.
Author's implementation: https://people.eecs.berkeley.edu/~youyang/lars_optimizer.py
TensorFlow official implementation:
https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/lars_optimizer.py
https://github.com/tensorflow/tpu/blob/master/models/official/resnet/lars_util.py

@frgfm
Owner

frgfm commented Jan 26, 2021

Thanks for the specifics @binmakeswell!

A few things about the above discussion:

  • In your first message, I don't see any part that specifically excludes normalization layers. Could you please point out the code section you are referring to?
  • While TensorFlow chose to enforce parameter skipping, I designed this optimizer so that you pass it the parameters you want to optimize at init. So if you think some layers should keep their current parameters, you simply avoid passing them (see the sketch right after this list).
  • I can confirm that, as of today, there are 5 versions of the paper (cf. https://arxiv.org/abs/1904.00962), so I'll check for the latest updates. However, the "official implementation" mentioned there was not published when the paper was first released, which is the version I implemented. You mentioned that this parameter skipping is described in the paper; could you please point out which section? 🙏
  • Regarding your last message, your first URL seems to be broken; could you let me know if you have another link for it?
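Here is a minimal sketch of the approach in the second bullet, assuming an optimizer with the usual torch.optim interface (AdamW stands in for this repo's LAMB below); the toy model and the normalization filter are illustrative only:

```python
import torch
from torch import nn

# Made-up toy model for illustration
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 10))

# Collect the ids of all normalization parameters, then leave them out of the optimizer
norm_ids = {
    id(p)
    for m in model.modules()
    if isinstance(m, (nn.modules.batchnorm._BatchNorm, nn.LayerNorm))
    for p in m.parameters()
}
params_to_optimize = [p for p in model.parameters() if id(p) not in norm_ids]

# Only the selected parameters are ever updated; swap in LAMB/LARS here
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-3)
```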

Looking forward to improving my implementation thanks to your feedback :D

@binmakeswell
Author

Thanks for your reply.
1. Specifically excluding normalization layers refers to "exclude_from_weight_decay" and "exclude_from_layer_adaptation"; you can see them in the official TensorFlow documentation and implementation (a sketch of their effect follows below). https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/LAMB
2. The authors do not give much explanation about this part in the paper, but their code shows it.
3. The author's new LARS implementation link is here: https://www.comp.nus.edu.sg/~youy/lars_optimizer.py
4. Another LAMB implementation in TensorFlow: https://github.com/fastalgo/imagenet_resnet50_lamb/blob/master/optimization.py
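For reference, here is a pseudocode-level sketch (not the official TensorFlow code) of how the two exclusion lists typically affect a single LAMB parameter update; the function name is made up, and the Adam moment computation is omitted:

```python
import torch

def lamb_param_update(param, adam_update, lr, weight_decay,
                      exclude_decay, exclude_adaptation):
    """Illustrative LAMB step for one parameter (assumes plain tensors / a no_grad context).

    adam_update is the usual m_hat / (sqrt(v_hat) + eps) term.
    """
    update = adam_update
    if weight_decay != 0 and not exclude_decay:    # exclude_from_weight_decay
        update = update + weight_decay * param
    if exclude_adaptation:                         # exclude_from_layer_adaptation
        trust_ratio = 1.0                          # no layer-wise rescaling
    else:
        w_norm, u_norm = param.norm(), update.norm()
        trust_ratio = (w_norm / u_norm).item() if w_norm > 0 and u_norm > 0 else 1.0
    param -= lr * trust_ratio * update
    return param
```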

@frgfm
Owner

frgfm commented Nov 6, 2021

Hi @binmakeswell 👋

I just opened a PR to support a different weight decay for normalization layers, meaning that the user can set the WD to 0 for those layers. That should cover the modification you mentioned in this issue :)
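For illustration, here is a minimal sketch of that idea (not necessarily the exact code in the PR): normalization parameters go into a parameter group whose weight decay is set to 0, using the standard torch.optim parameter-group mechanism. The model, the hyperparameter values, and the AdamW stand-in are placeholders.

```python
import torch
from torch import nn

# Made-up toy model for illustration
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 10))

norm_types = (nn.modules.batchnorm._BatchNorm, nn.GroupNorm, nn.LayerNorm)
norm_ids = {id(p) for m in model.modules() if isinstance(m, norm_types) for p in m.parameters()}

# One group with regular weight decay, one group (normalization params) with WD = 0
param_groups = [
    {"params": [p for p in model.parameters() if id(p) not in norm_ids], "weight_decay": 1e-2},
    {"params": [p for p in model.parameters() if id(p) in norm_ids], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(param_groups, lr=3e-4)  # AdamW stands in for the LAMB optimizer
```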

Let me know if you have any questions!

@frgfm added the type: enhancement (New feature or request) and module: trainer labels and removed the question label on Nov 6, 2021
@frgfm added this to the 0.2.0 milestone on Nov 6, 2021
@frgfm closed this as completed in #162 on Nov 7, 2021