Modify Muon optimizer #21885
Conversation
Summary of Changes (Gemini Code Assist): This pull request refines the Muon optimizer by enhancing its flexibility and robustness. It introduces a dedicated weight decay parameter for Adam-optimized variables, updates the learning rate adjustment logic to the more widely recognized Moonlight version for improved stability, and clarifies the optimizer's application scope to strictly 2D variables, delegating other dimensions to Adam. These changes aim to make Muon a more versatile and performant optimizer across diverse model architectures.
@fchollet @hertschuh
Code Review
This pull request introduces several valuable improvements to the Muon optimizer, including adding a separate weight decay for Adam, aligning with the Moonlight implementation, and restricting Muon updates to 2D variables for better stability. The changes are well-motivated and correctly implemented. I've provided a few suggestions to enhance docstring clarity and improve code readability, in line with the repository's style guide.
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

    @@            Coverage Diff             @@
    ##           master   #21885      +/-   ##
    ==========================================
    - Coverage   82.63%   82.56%    -0.07%
    ==========================================
      Files         577      578        +1
      Lines       59415    59804      +389
      Branches     9313     9389       +76
    ==========================================
    + Hits        49097    49377      +280
    - Misses       7913     8004       +91
    - Partials     2405     2423       +18

☔ View full report in Codecov by Sentry.
hertschuh left a comment
Have you tested this end-to-end with model.fit to see that it trains as expected?
Is there a way to compare with the original implementation?
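For reference, a minimal end-to-end smoke test along these lines might look like the sketch below (assuming the optimizer is exposed as `keras.optimizers.Muon`; the arguments shown are illustrative rather than the full signature):

```python
import numpy as np
import keras

# Tiny synthetic regression task: just checks that fit() runs and the loss decreases.
x = np.random.randn(1024, 32).astype("float32")
y = np.random.randn(1024, 1).astype("float32")

model = keras.Sequential(
    [
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1),
    ]
)

# 2D kernels take the Muon path; 1D biases fall back to the internal Adam path.
model.compile(optimizer=keras.optimizers.Muon(learning_rate=1e-3), loss="mse")

history = model.fit(x, y, batch_size=64, epochs=5, verbose=0)
print(history.history["loss"])  # expect a decreasing trend
```

Comparing with the original implementation would presumably mean running the same task against the Moonlight reference optimizer and checking that the loss curves roughly agree.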
    if len(variable.shape) != 2:
        return True
I'm not following this change. In the Moonlight implementation, the criterion for using Muon is that `ndim >= 2`: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296
So for AdamW, the criterion would be `ndim < 2`.
> I'm not following this change. In the Moonlight implementation, the criterion for using Muon is that `ndim >= 2`: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296
> So for AdamW, the criterion would be `ndim < 2`.
The optimization target of Muon is matrices. In the 3D case, reshaping into matrices would be necessary for effective optimization. However, that involves too many assumptions, and introducing it would only add unnecessary complexity. In fact, Muon never considered the case of CNNs; it was designed with only Transformer scenarios (2D weight matrices operating on 1D sequences) in mind.
> I'm not following this change. In the Moonlight implementation, the criterion for using Muon is that `ndim >= 2`: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296
> So for AdamW, the criterion would be `ndim < 2`.
In the original implementation of MoonLight, they could ensure that the optimization target is a Transformer model based on PyTorch. However, in the Keras implementation, we cannot guarantee this. For example, in a typical case with the PyTorch backend, if we mix keras.layers.Dense and torch.nn.Linear, then the optimization targets would simultaneously include variables of shape [d_out, d_in] and [d_in, d_out].
Similarly, if the optimization target is a 3D CNN model, the parameter meanings for the CNN model differ between the "channels_last" and "channels_first" formats. We lack reasonable assumptions to perform reshaping in such cases.
The Muon optimizer in Keras should be a general-purpose optimizer, and a general-purpose optimizer should not rely on too many assumptions. Therefore, we can only use the most conservative approach: we do not optimize anything other than matrices.
This is also the reason why we do not use the Keller Jordan version: it assumes one fixed orientation for the optimized matrix (either [d_out, d_in] or [d_in, d_out], depending on the framework convention), while Moonlight does not require such an assumption.
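To make the orientation issue concrete, here is a small illustration (assuming both Keras and Torch are available, as in the mixed-layer scenario described above):

```python
import keras
import torch

# The same logical layer (128 inputs, 64 outputs) stores its kernel in opposite orders.
dense = keras.layers.Dense(64)
dense.build((None, 128))
print(dense.kernel.shape)   # (128, 64)            -> [d_in, d_out]

linear = torch.nn.Linear(128, 64)
print(linear.weight.shape)  # torch.Size([64, 128]) -> [d_out, d_in]

# A scaling rule of the form f(shape[0] / shape[1]) gives different values for these
# two kernels, whereas a rule that is symmetric in the two dimensions (as in the
# Moonlight version) treats both the same.
```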
keras/src/optimizers/muon.py (Outdated)

    self.assign_sub(
        variable,
    -    lr
    -    * self.zeropower_via_newtonschulz5(g, self.ns_steps)
    -    * max(1, shape[0] / shape[1]) ** 0.5,
    +    self.lr_adjust(lr * update),
    )
nitpick: this can be on a single line now.
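For context on what `zeropower_via_newtonschulz5` computes: it approximates the "matrix sign" of the momentum (the U Vᵀ factor of its SVD) via a quintic Newton-Schulz iteration. A standalone NumPy sketch is below; the coefficients are those from Keller Jordan's reference implementation, and whether the Keras helper matches it exactly is an assumption:

```python
import numpy as np

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """Approximately map G to the nearest semi-orthogonal matrix (U V^T of its SVD)."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients
    X = G / (np.linalg.norm(G) + eps)  # Frobenius normalization bounds the spectral norm by 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:  # work in the wide orientation so X @ X.T is the smaller Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X

G = np.random.randn(64, 128)
O = newton_schulz_orthogonalize(G)
s = np.linalg.svd(O, compute_uv=False)
print(s.min(), s.max())  # singular values are pushed toward ~1 (roughly the 0.7-1.2 range)
```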
    continue
    wd = ops.cast(weight_decay_value, variable.dtype)
    lr = ops.cast(self.learning_rate, variable.dtype)
    variable.assign(variable - variable * wd * lr)
Use `self.assign(variable, variable - variable * wd * lr)`
> Use `self.assign(variable, variable - variable * wd * lr)`

keras/keras/src/optimizers/base_optimizer.py, line 967 in 846a297:

    variable.assign(variable - variable * wd * lr)
Here, I maintain consistency with the existing weight decay implementation.
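For reference, that existing step is plain decoupled weight decay, i.e. scaling the variable by (1 - lr * wd); a quick NumPy check of the equivalence:

```python
import numpy as np

w = np.array([1.0, -2.0, 3.0])
lr, wd = 0.01, 0.1

decayed = w - w * wd * lr  # the form used in base_optimizer.py
assert np.allclose(decayed, w * (1.0 - lr * wd))
```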
This is one of my pre-training tasks. Each epoch consists of 512 steps. It can be observed that Moonlight Muon not only trains normally, but also has a more stable loss than Adam.
@hertschuh

**A Guide to the Muon Optimizer: Quick Start and Key Details**

By now, many readers will already have come across news about the Muon optimizer. Muon was first proposed around last October by Keller Jordan on Twitter, which means it has been around for just over a year. In that single year, however, Muon has already endured the training scrutiny of models with billions, hundreds of billions, and even trillions of parameters, which is enough to prove that it is a highly competitive optimizer. Muon is now built into training frameworks like Torch and Keras, and even large-scale frameworks like Megatron are gradually starting to support it, indicating that it has gained general acceptance in the industry. However, for readers who are only familiar with Adam, how to switch to Muon quickly and effectively may still be confusing, so this article attempts to provide a quick-start guide.

**Brief Introduction**

The official proposer of Muon is Keller Jordan, who currently works at OpenAI. As mentioned, Muon was first published on Twitter, and to this day the author has only written a blog post, ["Muon: An optimizer for hidden layers in neural networks"](https://kellerjordan.github.io/posts/muon/), instead of a paper. His view is that whether or not it is written up as a paper has nothing to do with whether the optimizer is effective.

Muon is an optimizer specifically customized for matrix parameters. There are related works with similar characteristics, such as Shampoo and the earlier Stochastic Spectral Descent. Many works can be associated with Muon to a greater or lesser extent, but none completely covers it, so the author considers Muon a completely new piece of work. In China, the earliest article to popularize Muon was probably this author's blog post, "Muon Optimizer Appreciation: An Essential Leap from Vector to Matrix", and the first large-scale model to verify Muon was our Moonlight, released in February, whose proposed Moonlight version of Muon was then used in the subsequent trillion-parameter K2 model. Following K2, GLM-4.5 also used this Muon variant.

As Jeremy Bernstein, one of Muon's authors, stated in his blog post "Deriving Muon", Muon's uniqueness lies in the fact that it can be derived from more fundamental optimization principles and is effective in practice. In contrast, although Adam is also very effective, it is more like a heuristic solution.

**Four Versions**

This article does not intend to introduce the mathematical details or the implementation of Muon; it focuses primarily on the technical details and precautions for switching from Adam to Muon. As mentioned, Muon is specifically for matrix-parameter optimization and uses a non-element-wise update rule, which can be confusing for new users. Furthermore, as far as the author knows, there are currently at least four slightly different versions of Muon, and this multi-version situation exacerbates the confusion. If users don't understand the details, they might get poor results by setting the wrong hyperparameters (especially the learning rate). The following sections clarify these details.
First, for a matrix parameter there are four variants: the Naive version, the Keller Jordan version, the MuP version, and the Moonlight version. They share the same momentum and orthogonalization steps; to enable Nesterov momentum, the momentum term in the update is simply replaced with its Nesterov look-ahead form. The only difference between the four versions is the scaling factor placed in front of the orthogonalized update (a hedged sketch of these scaling factors is given at the end of this article).

**The Two Dimensions**

Here we must pay attention to an important detail: the Keller Jordan version and the MuP version are sensitive to the order of the two dimensions, i.e. to whether the kernel is stored as [d_out, d_in] or [d_in, d_out]. Therefore, to implement the Keller Jordan version of Muon you have to account for the fact that Torch's Linear layer stores its weight as [d_out, d_in]; if you write your own model, you need to judge carefully based on your own code. Of course, if you find figuring this out too troublesome, you can consider using the Moonlight version, whose scaling factor is symmetric with respect to the two dimensions.

**Hyperparameter Settings**

After clarifying the version differences, let's look at the Moonlight version first. Its scaling factor is derived by aligning with the update RMS of Adam. Simply put, the Moonlight version of Muon matches Adam's update magnitude, so the simplest way to migrate from Adam is: don't change anything, and just reuse the same learning rate and weight decay. Next, consider the remaining three versions. For the matrix shapes found in mainstream models, their scaling factors no longer match Adam's update magnitude, so the learning rate generally has to be re-tuned when switching.

Does this mean the Moonlight version is easier to use? The Moonlight version indeed has good practical results, but saying it is "better" is evaluating it from the perspective of Adam. The advantage of the MuP version or the Keller Jordan version is learning rate transferability, meaning that a learning rate tuned on a small model often works well when applied directly to a large model.

**Other Parameters**

If Muon only handles matrix parameters, what about the other parameters, for example the bias term of linear layers or the scale parameters of normalization layers? Let me first correct myself: Muon does not just handle matrix parameters; Muon only handles "matrix parameters of densely-input linear layers." If the reader finds this confusing, just remember that the matrix parameters of the Embedding layer and the final classification layer (including GPT's LM head) should not use Muon, or the effect will be noticeably worse.

For these matrix parameters that cannot use Muon, as well as 1D, 3D, and higher-dimensional parameters, readers who do not want to overthink it can simply use Adam. Muon implementations are basically a mix with Adam, allowing users to select certain layers to use Adam. If the reader is willing to tinker, then 3D or 4D parameters, such as those in convolution layers, can also use Muon: taking Conv2D as an example, the kernel is usually stored as [kernel_h, kernel_w, c_in, c_out], and it can be reshaped into a matrix before applying Muon. Similar reshaping tricks apply to other higher-dimensional kernels. Never stop tinkering!

**Expected Results**

Finally, if the user follows the instructions above, sets everything up correctly, and starts running, then they can begin praying for the arrival of the goddess of luck. What kind of result should we expect? If there are no anomalies such as gradient explosion, Muon will usually be slightly better than Adam in most cases. Of course, it is not ruled out that Muon may be slightly worse in some situations, but in any case the difference between them will not be very large; if one is significantly better than the other, it may be necessary to reconsider whether there was a problem with the settings on either side. However, this is not absolute: under certain extreme settings it is possible for Muon to be much better than Adam, with Adam failing to converge no matter how it is tuned, whereas the opposite is rarely encountered. In summary, I wish you good luck.
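As a rough reference, the scaling factors can be summarized in code. This is a hedged sketch, not the Keras implementation: the Keller Jordan expression mirrors `max(d_out/d_in, 1)**0.5` from this PR's description, the Moonlight expression follows the update-RMS-matching derivation (the RMS of an orthogonalized update is about `1/sqrt(max(d_out, d_in))`, so multiplying by `0.2 * sqrt(max(d_out, d_in))` aligns it with Adam's typical update RMS), and the MuP expression is only an assumption about that variant.

```python
import math

def muon_update_scale(version, d_out, d_in, rms_rate=0.2):
    """Illustrative scaling factor applied in front of the orthogonalized update."""
    if version == "naive":
        return 1.0
    if version == "keller_jordan":
        # Orientation-sensitive: assumes you know which axis is d_out.
        return max(d_out / d_in, 1.0) ** 0.5
    if version == "mup":
        # Assumed form (sqrt(fan_out / fan_in)); also orientation-sensitive.
        return (d_out / d_in) ** 0.5
    if version == "moonlight":
        # Symmetric in the two dimensions, matches Adam's update RMS (~rms_rate).
        return rms_rate * math.sqrt(max(d_out, d_in))
    raise ValueError(f"unknown version: {version}")

# The Moonlight factor does not change if the two dimensions are swapped:
assert muon_update_scale("moonlight", 1024, 4096) == muon_update_scale("moonlight", 4096, 1024)
```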
If you encounter any interesting phenomena, you are welcome to discuss and analyze them together.

Source link: https://kexue.fm/archives/11416
In this PR, we have introduced three improvements to Muon:
1. In the Muon optimizer, we often designate a subset of variables to be optimized with Adam. Since different optimizers should not be assumed to share the same weight decay, we addressed this by adding an `adam_weight_decay` parameter.
2. The current implementation of Muon mainly follows the Keller Jordan version. However, the Moonlight version is now widely recognized as superior. Compared to the Keller Jordan version, the Moonlight version changes the learning rate adjustment from `max(d_out/d_in, 1)**0.5` to `max(d_out, d_in) * rate`. The Keller Jordan version assumes that the second dimension is the output dimension and the first dimension is the input dimension; as a general-purpose optimizer, we should not make such assumptions.
   Additionally, the Moonlight version allows Muon and Adam to share the same weight decay and learning rate. We have added an `rms_rate` parameter to enable this feature, with a default value of 0.2; it can be disabled by setting it to `None`. We have also adjusted some default parameters to match the Moonlight version.
3. When we initially submitted the Muon optimizer, our understanding of it was not deep enough. As our research progressed, we realized that Muon was designed with the assumption that the model is a Transformer. For 3D weights, one would have to assume that one dimension is `d_in` and reshape the remaining dimensions into `d_out`; unlike the 2D case, however, the 3D scenario does not always have a clear distinction between `d_in` and `d_out`. Therefore, out of caution, we only use the Adam optimizer for anything other than 2D variables (see the sketch after this list).
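A minimal sketch of how the first and third points fit together (the helper names and default values below are hypothetical, not the actual Keras methods):

```python
def uses_adam(variable_shape):
    """Conservative rule from point 3: only 2D kernels take the Muon path."""
    return len(variable_shape) != 2

def weight_decay_for(variable_shape, weight_decay=0.01, adam_weight_decay=0.004):
    """Point 1: Adam-managed variables get their own weight decay value."""
    return adam_weight_decay if uses_adam(variable_shape) else weight_decay

print(weight_decay_for((4096, 1024)))     # 0.01  -> dense kernel, Muon path
print(weight_decay_for((1024,)))          # 0.004 -> bias, Adam path
print(weight_decay_for((3, 3, 64, 128)))  # 0.004 -> conv kernel, Adam path in this PR
```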