MuAdam not adjusting lr for output weights #7
@edwardjhu Thank you, Edward! Your answer resolved my confusion. Just to double-check: if I need to implement a custom output layer, Table 8 means that I should initialize the output weight with std 1 and always divide the layer's output by fanin, right?
That's right!
@zhuzilin I want to fill in some information here that may have been lost in the subcontext. We don't want you to use exactly std=1 and divide the output layer by exactly fanin. You should interpret the 1 as O(1) and the fanin as O(fanin). In other words, the table just says that when you double your fanin, the multiplier on the last layer should be halved, but the initialization should be unchanged. The exact numbers you use for the initialization and the multiplier should be tuned on some base model. This discussion applies to all other parameters in the table as well.

Regarding the output layer specifically, we actually recommend initializing it at 0 if possible (assuming you don't have tricky weight tying between the input and output weights). This should not affect the performance of your model after training, but it typically improves the transfer quality. See Section D.2 in the paper for more details.
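To make the above concrete, here is a minimal sketch of what such a custom output layer could look like in PyTorch, assuming no weight tying. The class name `SimpleMuReadout` and the `base_fan_in` / `output_mult` arguments are illustrative only, not the mup library's actual `MuReadout` API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMuReadout(nn.Module):
    """Illustrative muP-style output layer (not the library's MuReadout).

    The forward multiplier shrinks as O(1/fanin) relative to a tuned base
    width, while the weight initialization does not change with width
    (here zero init, as recommended when there is no weight tying).
    """
    def __init__(self, fan_in, fan_out, base_fan_in, output_mult=1.0):
        super().__init__()
        # Initialize the output weight (and bias) at 0.
        self.weight = nn.Parameter(torch.zeros(fan_out, fan_in))
        self.bias = nn.Parameter(torch.zeros(fan_out))
        # output_mult is the O(1) constant tuned on the base model;
        # doubling fan_in halves the effective multiplier below.
        self.width_mult = fan_in / base_fan_in
        self.output_mult = output_mult

    def forward(self, x):
        return F.linear(x, self.weight, self.bias) * (self.output_mult / self.width_mult)
```

The point is that only the ratio `output_mult / width_mult` changes with width; `output_mult` itself is tuned once on the base model and then reused at every width.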
Hi, thank you for your great project for hyperparameter tuning!
As our team is migrating mup to another training framework, we noticed that MuAdam does not scale the learning rate for the output weights as the TP5 paper illustrates: mup/mup/optim.py, lines 55 to 70 at c9d6700.
It seems to us that only the lr of the hidden layers (layers with two infinite dimensions) is scaled w.r.t. fanin, while the output weights are ignored. We wonder if this is intended. Thank you!
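For reference, here is a minimal sketch of the kind of per-group Adam lr scaling being discussed, assuming PyTorch. The helper name `make_mup_adam_param_groups` and its ndim-based heuristic for spotting matrix-like parameters are illustrative assumptions, not how mup's MuAdam actually classifies parameters:

```python
import torch

def make_mup_adam_param_groups(named_params, base_lr, width_mult):
    """Illustrative sketch of muP-style Adam lr scaling (not mup's MuAdam).

    Parameters with two infinite (width) dimensions, i.e. hidden weight
    matrices, get their Adam learning rate divided by the width multiplier;
    vector-like parameters (biases, layernorm gains) keep the base lr.
    Here any tensor with ndim >= 2 is treated as matrix-like; real code
    would track which dimensions actually scale with width.
    """
    matrix_like, vector_like = [], []
    for name, p in named_params:
        (matrix_like if p.ndim >= 2 else vector_like).append(p)
    return [
        {"params": matrix_like, "lr": base_lr / width_mult},
        {"params": vector_like, "lr": base_lr},
    ]

# Example usage (width_mult = fan_in / base_fan_in of the widened model):
# optimizer = torch.optim.Adam(
#     make_mup_adam_param_groups(model.named_parameters(), base_lr=1e-3, width_mult=4.0)
# )
```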