Conversation

@pass-lin
Contributor

@pass-lin pass-lin commented Dec 2, 2025

In this PR, we have introduced three improvements to Muon:

1. In the Muon optimizer, we often designate a subset of variables to be optimized with Adam. However, different optimizers should not be assumed to share the same weight decay value, so we added an adam_weight_decay parameter.

2. The current implementation of Muon mainly follows the Keller Jordan version. However, the Moonlight version is now widely recognized as superior. Compared to the Keller Jordan version, the Moonlight version changes the update scaling factor from max(d_out/d_in, 1)**0.5 to max(d_out, d_in)**0.5 * rate. The Keller Jordan version assumes that the second dimension is the output dimension and the first dimension is the input dimension; as a general-purpose optimizer, we should not make such assumptions.
Additionally, the Moonlight version allows Muon and Adam to share the same weight decay and learning rate. We have added an rms_rate parameter to enable this feature, with a default value of 0.2; it can be disabled by setting it to None. We have also adjusted some default parameters based on the Moonlight version (see the configuration sketch after this list).

3. When we initially submitted the Muon optimizer, our understanding of Muon was not deep enough. As our research progressed, we discovered that Muon was designed under the assumption that the model is a Transformer. For 3D weights, one dimension must be assumed to be d_in and the remaining dimensions reshaped into d_out. Unlike the 2D case, however, the 3D scenario does not always have a clear distinction between d_in and d_out. Therefore, out of caution, we use the Adam optimizer for anything other than 2D variables.
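Below is a minimal configuration sketch of how these options fit together. The argument names adam_weight_decay and rms_rate follow this PR's description, while exclude_layers is an assumed name for the existing "use Adam for these layers" mechanism; verify all of them against the merged docstring before copying.

```python
import keras

# Hypothetical configuration of the updated Muon optimizer:
# - Muon updates are applied only to 2D kernels; all other variables
#   (and any excluded layers) fall back to Adam.
# - adam_weight_decay decouples Adam's weight decay from Muon's.
# - rms_rate=0.2 enables the Moonlight-style scaling so the same learning
#   rate and weight decay as Adam can be reused; None disables it.
optimizer = keras.optimizers.Muon(
    learning_rate=3e-4,
    weight_decay=0.004,
    adam_weight_decay=0.004,
    rms_rate=0.2,
    exclude_layers=["embedding", "lm_head"],  # assumed arg: layers kept on Adam
)
# model.compile(optimizer=optimizer, loss=...) and model.fit(...) as usual.
```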

@gemini-code-assist
Contributor

Summary of Changes

Hello @pass-lin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refines the Muon optimizer by enhancing its flexibility and robustness. It introduces a dedicated weight decay parameter for Adam-optimized variables, updates the learning rate adjustment logic to a more widely recognized 'Moonlight version' for improved stability, and clarifies the optimizer's application scope to strictly 2D variables, delegating other dimensions to Adam. These changes aim to make Muon a more versatile and performant optimizer, especially in diverse model architectures.

Highlights

  • Separate Adam Weight Decay: Introduced an adam_weight_decay parameter to allow for distinct weight decay settings when the Adam optimizer is used for a subset of variables, addressing the issue of different optimizers potentially requiring different weight decay values.
  • Muon Learning Rate Adjustment (Moonlight Version): Updated the Muon optimizer's learning rate adjustment mechanism to align with the 'Moonlight version', which modifies the scaling factor from max(d_out/d_in, 1)**0.5 to max(d_out, d_in)**0.5 * rate. A new rms_rate parameter (default 0.2) was added to enable this feature, allowing Muon and Adam to maintain consistent learning rates and weight decay; it can be disabled by setting it to None.
  • Strict 2D Variable Optimization for Muon: Refined the application scope of the Muon optimizer to strictly apply only to 2D variables. For any variables that are not 2D (e.g., 0D, 1D, 3D, or higher dimensions), the Adam optimizer will now be used, based on the understanding that Muon was originally designed with Transformer models and 2D weights in mind.
  • Default Parameter Adjustments: Adjusted several default parameters within the Muon optimizer, including adam_lr_ratio (from 0.1 to 1), weight_decay (from 0.1 to 0.004), and ns_steps (from 6 to 5), to better align with the 'Moonlight version' and improved practices.

@pass-lin
Contributor Author

pass-lin commented Dec 2, 2025

@fchollet @hertschuh
In this PR, we removed the TF bug fix from #21859, because #21797 provides a more accurate fix. Here we focus on improving the Muon algorithm, as the current Keras Muon optimizer still has many issues; these stem from our initially insufficient understanding of Muon.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces several valuable improvements to the Muon optimizer, including adding a separate weight decay for Adam, aligning with the Moonlight implementation, and restricting Muon updates to 2D variables for better stability. The changes are well-motivated and correctly implemented. I've provided a few suggestions to enhance docstring clarity and improve code readability, in line with the repository's style guide.

@codecov-commenter

codecov-commenter commented Dec 2, 2025

Codecov Report

❌ Patch coverage is 80.95238% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.56%. Comparing base (a0004ee) to head (4e4f375).
⚠️ Report is 36 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/optimizers/muon.py 80.95% 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21885      +/-   ##
==========================================
- Coverage   82.63%   82.56%   -0.07%     
==========================================
  Files         577      578       +1     
  Lines       59415    59804     +389     
  Branches     9313     9389      +76     
==========================================
+ Hits        49097    49377     +280     
- Misses       7913     8004      +91     
- Partials     2405     2423      +18     
Flag Coverage Δ
keras 82.37% <80.95%> (-0.08%) ⬇️
keras-jax 62.89% <80.95%> (-0.44%) ⬇️
keras-numpy 57.46% <80.95%> (-0.11%) ⬇️
keras-openvino 34.33% <9.52%> (+0.02%) ⬆️
keras-tensorflow 64.42% <80.95%> (+0.30%) ⬆️
keras-torch 63.60% <80.95%> (-0.03%) ⬇️

☔ View full report in Codecov by Sentry.

@pass-lin pass-lin mentioned this pull request Dec 2, 2025
@pass-lin pass-lin changed the title modify muon. Modify Muon optimizer Dec 2, 2025
Collaborator

@hertschuh hertschuh left a comment

Have you tested this end-to-end with model.fit to see that it trains as expected?

Is there a way to compare with the original implementation?

Comment on lines +144 to 145
if len(variable.shape) != 2:
return True
Collaborator

I'm not following this change. In the Moonlight implementation the criteria for using Muon is that the ndim >= 2: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296

So for AdamW, the criteria would be ndim < 2.

Contributor Author

I'm not following this change. In the Moonlight implementation the criteria for using Muon is that the ndim >= 2: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296

So for AdamW, the criteria would be ndim < 2.

The optimization target of Muon is matrices. In the 3D case, reshaping into matrices is necessary for effective optimization. However, that involves too many assumptions, and introducing it would only add unnecessary complexity. In fact, Muon never considered the case of CNNs; it was designed only with (1D) Transformer scenarios in mind.

Contributor Author

I'm not following this change. In the Moonlight implementation the criteria for using Muon is that the ndim >= 2: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296

So for AdamW, the criteria would be ndim < 2.

In the original Moonlight implementation, they could ensure that the optimization target is a PyTorch-based Transformer model. In the Keras implementation, we cannot guarantee this. For example, a typical case with the PyTorch backend: if we mix keras.layers.Dense and torch.nn.Linear, the optimization targets would simultaneously include variables of shape [d_out, d_in] and [d_in, d_out].

Similarly, if the optimization target is a 3D CNN model, the parameter meanings for the CNN model differ between the "channels_last" and "channels_first" formats. We lack reasonable assumptions to perform reshaping in such cases.

The Muon optimizer in Keras should be a general-purpose optimizer, and a general-purpose optimizer should not rely on too many assumptions. Therefore, we can only use the most conservative approach: we do not optimize anything other than matrices.

This is also why we do not use the Keller Jordan version: it is sensitive to whether the optimized matrix is laid out as [d_out, d_in] or [d_in, d_out], while the Moonlight version does not require such an assumption.

Comment on lines 203 to 206
self.assign_sub(
    variable,
-   lr
-   * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-   * max(1, shape[0] / shape[1]) ** 0.5,
+   self.lr_adjust(lr * update),
)
Collaborator

nitpick: this can be on a single line now.

continue
wd = ops.cast(weight_decay_value, variable.dtype)
lr = ops.cast(self.learning_rate, variable.dtype)
variable.assign(variable - variable * wd * lr)
Collaborator

Use self.assign(variable, variable - variable * wd * lr)

Contributor Author

Use self.assign(variable, variable - variable * wd * lr)

variable.assign(variable - variable * wd * lr)

Here, I maintain consistency with the existing weight decay implementation.

@pass-lin
Contributor Author

pass-lin commented Dec 3, 2025

[Image: training loss curves comparing Moonlight Muon with Adam]

This is one of my pre-training tasks; each epoch consists of 512 steps. It can be observed that Moonlight Muon not only trains normally but also has a more stable loss than Adam.
@hertschuh

@pass-lin
Contributor Author

pass-lin commented Dec 3, 2025

@hertschuh
I'd like to recommend a blog post to you. It is written by Su Jianlin, the author of Moonlight and RoPE. Since it is a Chinese blog, I have translated it into English for you using Gemini.

A Guide to the Muon Optimizer: Quick Start and Key Details

By now, I believe many readers have already come across news about the Muon optimizer. Muon was first proposed around last October by Keller Jordan on Twitter, which means it has been around for just over a year. Yet in that single year, Muon has already withstood the scrutiny of training models with billions, hundreds of billions, and even trillions of parameters, which is enough to show that it is a highly competitive optimizer.

Muon is now built into training frameworks like Torch and Keras, and even large-scale frameworks like Megatron are gradually starting to support it, indicating that it has gained broad acceptance in the industry. However, for readers who are only familiar with Adam, how to switch to Muon quickly and effectively may still be confusing. This article therefore attempts to provide a quick-start guide.

Brief Introduction

The official proposer of Muon is Keller Jordan, who currently works at OpenAI. As mentioned earlier, Muon was first published on Twitter, and to this day the author has only written a blog post, ["Muon: An optimizer for hidden layers in neural networks"](https://kellerjordan.github.io/posts/muon/), rather than a paper. The author's view is that whether or not it is written up as a paper has nothing to do with whether the optimizer is effective.

Muon is an optimizer specifically customized for matrix parameters. There are related works with similar characteristics, such as Shampoo and the earlier Stochastic Spectral Descent. Many works can be connected to Muon to a greater or lesser extent, but none completely covers it, so the author considers Muon a genuinely new piece of work.

In China, the earliest article to popularize Muon was probably the author's blog post "Muon Optimizer Appreciation: An Essential Leap from Vector to Matrix". The first large-scale model to validate Muon was our Moonlight, released in February; the Moonlight version of Muon it proposed was then used in the subsequent trillion-parameter K2 model, and after K2, GLM-4.5 also adopted this Muon variant.

As Jeremy Bernstein, one of Muon's authors, discussed in his blog post Deriving Muon, and in this author's view as well, Muon's uniqueness lies in the fact that it can be derived from more fundamental optimization principles while also being effective in practice. In contrast, although Adam is also very effective, it is more of a heuristic solution.

Four Versions

This article does not intend to introduce the mathematical details or the implementation of Muon but focuses primarily on the technical details and precautions for switching from Adam to Muon. As mentioned, Muon is specifically for matrix parameter optimization and uses a non-element-wise update rule, which can be confusing for new users.

Furthermore, as far as the author knows, there are currently at least four slightly different versions of Muon, and this multi-version situation adds to the confusion. Users who don't understand the details may get poor results by setting the wrong hyperparameters (especially the learning rate). The following sections clarify these details. For a matrix $W \in \mathbb{R}^{d_{in} \times d_{out}}$ with gradient $G_t$ at step $t$, all four variants share the momentum update

$$ M_t = \beta M_{t-1} + G_t $$

and differ only in the weight update:

Naive Version:

$$W_t = W_{t-1} - \eta_t (\mathrm{msign}(M_t) + \lambda W_{t-1})$$

Keller Jordan Version:

$$W_t = W_{t-1} - \eta_t \left( \sqrt{\max(1, d_{out}/d_{in})} \mathrm{msign}(M_t) + \lambda W_{t-1} \right)$$

MuP Version:

$$W_t = W_{t-1} - \eta_t \left( \sqrt{d_{out}/d_{in}} \mathrm{msign}(M_t) + \lambda W_{t-1} \right)$$

Moonlight Version:

$$W_t = W_{t-1} - \eta_t \left( 0.2 \times \sqrt{\max(d_{out}, d_{in})} \mathrm{msign}(M_t) + \lambda W_{t-1} \right)$$

To enable Nesterov momentum, replace $\text{msign}(M_t)$ with $\text{msign}(\beta M_t + G_t)$. The $\text{msign}$ operation is usually named zeropower_via_newtonschulz in implementation, but ordinary users do not need to worry about the specific implementation details.
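For readers who do want to peek inside, here is a minimal NumPy sketch of what a zeropower_via_newtonschulz routine typically does: it approximates $\text{msign}(M) = UV^\top$ (from the SVD $M = U\Sigma V^\top$) with a few quintic Newton-Schulz iterations. The coefficients below are the ones commonly quoted from Keller Jordan's reference implementation; treat them, and the whole sketch, as an assumption rather than the exact code of any particular framework.

```python
import numpy as np

def msign_newton_schulz(g, steps=5, eps=1e-7):
    # Approximate msign(G) = U @ V.T, where G = U @ S @ V.T is the SVD.
    a, b, c = 3.4445, -4.7750, 2.0315     # assumed quintic coefficients
    x = g / (np.linalg.norm(g) + eps)      # scale so singular values are <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:                         # iterate on the wide orientation
        x = x.T
    for _ in range(steps):
        xxt = x @ x.T
        x = a * x + (b * xxt + c * xxt @ xxt) @ x
    return x.T if transposed else x
```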

The only difference between the four versions is the scaling factor before $\text{msign}$. The "Keller Jordan Version" and the "MuP Version" are largely similar, while the "Moonlight Version" is slightly more unique. Keras has only implemented the "Keller Jordan Version," while Torch has implemented the "Keller Jordan Version" and the "Moonlight Version." The Naive Version seems to be relatively uncommon. The author of this article frequently uses their self-written "MuP Version."

The Two Dimensions

Here, we must pay attention to an important detail: the "Keller Jordan Version" and the "MuP Version" are sensitive to the order of $d_{in}$ and $d_{out}$. Therefore, the first step is to clarify the meaning of $d_{in}$ and $d_{out}$; it is not the case that the first dimension of the matrix is always $d_{in}$ and the second dimension is $d_{out}$.

$d_{in}$ and $d_{out}$ refer to the input and output dimensions of the linear layer, respectively. Determining which is $d_{in}$ and which is $d_{out}$ depends on the specific implementation of the linear layer. For example, in Keras's Dense layer, the implementation is $xW$, so the matrix $W$'s first dimension is $d_{in}$ and the second is $d_{out}$. However, Torch's Linear layer implements $xW^\top$, so the matrix $W$'s second dimension is $d_{in}$ and the first is $d_{out}$.

Therefore, to implement the "Keller Jordan Version" of Muon, the scaling factor for Torch's Linear layer should be $\sqrt{\max(1, W\text{.shape}[0]/W\text{.shape}[1])}$, while for Keras, it should be $\sqrt{\max(1, W\text{.shape}[1]/W\text{.shape}[0])}$. Consequently, the current Keras Muon implementation is actually incorrect because it copied Torch's scaling factor implementation.

If you write your own model, you need to judge carefully based on your own code. Of course, if you find figuring this out too troublesome, you can consider using the "Moonlight Version," whose scaling factor is symmetric with respect to $d_{in}$ and $d_{out}$.
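A quick way to check which convention your own layers follow is to inspect the kernel shapes directly. The snippet below is only a sanity check (it assumes both Keras and PyTorch are installed), not part of any optimizer:

```python
import keras
import torch

dense = keras.layers.Dense(8)
dense.build((None, 4))
print(dense.kernel.shape)    # (4, 8): kernel is (d_in, d_out), since Dense computes x @ W

linear = torch.nn.Linear(4, 8)
print(linear.weight.shape)   # torch.Size([8, 4]): weight is (d_out, d_in), since Linear computes x @ W.T
```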

Hyperparameter Settings

After clarifying $d_{in}$ and $d_{out}$, the remaining issue is how to set the learning rate $\eta_t$ and the weight decay coefficient $\lambda$. The assumption here is that the user already has experience tuning Adam, has achieved good results with Adam, and wants to quickly switch to Muon for a trial run.

Let's look at the "Moonlight Version" first. Its scaling factor is derived by aligning with the Update RMS of Adam. Simply put, the "Moonlight Version" Muon aligns with Adam's update magnitude, so the simplest way to migrate from Adam is: don't change anything; just use the same $\eta_t$ and $\lambda$ as Adam.

Next, consider the remaining three versions. We know that mainstream models usually have a hidden_size (denoted as $d$), and the shapes of the model's matrices mostly do not deviate significantly from $d \times d$, so we can approximate by setting $d_{in} = d_{out} = d$. In this case, the three versions are identical and lack the factor of $0.2 \sqrt{d}$ compared to the "Moonlight Version." Since the "Moonlight Version" aligns with Adam's update magnitude without changing hyperparameters, the learning rate for these three versions should be scaled up by a factor of $0.2 \sqrt{d}$ to align with Adam's update magnitude, and correspondingly $\lambda$ should be divided by $0.2 \sqrt{d}$.

Substituting $d=1024, 2048, 4096$, the results for $0.2 \sqrt{d}$ are approximately $6.4, 9, 12.8$. If you can't remember $0.2 \sqrt{d}$, you can simply remember that if we use the other three versions of Muon, we should generally multiply the Adam learning rate by $10$ to use as the Muon learning rate. If you directly plug the Adam learning rate into Muon, you will get the conclusion that Muon is far inferior to Adam due to underfitting. As far as the author knows, some negative reviews of Muon stem from this.
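As a concrete illustration of this rule of thumb, here is a tiny helper, a hypothetical convenience function rather than anything from a library, that converts Adam hyperparameters for the non-Moonlight variants:

```python
import math

def adam_to_muon_hparams(adam_lr, adam_wd, hidden_size):
    # For the Naive / Keller Jordan / MuP versions with roughly square
    # d x d matrices: multiply the learning rate by 0.2 * sqrt(d) to match
    # Adam's update RMS, and divide the weight decay by the same factor.
    # (With the Moonlight version, just keep Adam's values unchanged.)
    scale = 0.2 * math.sqrt(hidden_size)
    return adam_lr * scale, adam_wd / scale

print(adam_to_muon_hparams(3e-4, 0.1, hidden_size=2048))
# -> (~2.7e-3, ~1.1e-2), i.e. roughly "multiply Adam's learning rate by 10"
```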

Does this mean the "Moonlight Version" is easier to use? It does indeed work well in practice, but calling it "better" here is only from the perspective of migrating from Adam. The advantage of the "MuP Version" or the "Keller Jordan Version" is learning-rate transferability: a learning rate tuned on a small model often works well when applied directly to a large model.

Other Parameters

If Muon only handles matrix parameters, what about the other parameters? For example, the Bias term of linear layers or the $\gamma$ term of RMSNorm are 1-dimensional parameters; and convolution layers might have 3-dimensional or 4-dimensional array parameters.

Let me first correct myself: Muon does not handle all matrix parameters; it only handles the matrix parameters of linear layers whose inputs are dense. If the reader finds this confusing, just remember that the matrix parameters of the Embedding layer and the final classification layer (including a GPT's LM Head) should not use Muon, or the results will be noticeably worse. For these matrix parameters that cannot use Muon, as well as for 1D, 3D, and higher-dimensional parameters, readers who do not want to overthink it can simply use Adam. Muon implementations essentially always come mixed with Adam, allowing users to designate certain layers to use Adam.

If the reader is willing to tinker, then 3D or 4D parameters, such as those in convolution layers, can also use Muon. Taking Conv2D as an example, the kernel shape is usually $(w, h, d_{in}, d_{out})$. An equivalent implementation of the convolution is to flatten each $(w, h, d_{in})$ patch of the input into a vector of length $w \times h \times d_{in}$ and reshape the kernel to $(w \times h \times d_{in}, d_{out})$ before performing a matrix multiplication. So, to use Muon, you first reshape the momentum to $(w \times h \times d_{in}, d_{out})$, compute $\text{msign}$, and then reshape it back for the update.
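A minimal sketch of that reshape-then-msign trick, assuming a channels_last kernel and some msign_fn such as the Newton-Schulz routine sketched earlier:

```python
import numpy as np

def conv2d_muon_direction(momentum_kernel, msign_fn):
    # momentum_kernel has shape (w, h, d_in, d_out) (channels_last).
    # Flatten (w, h, d_in) into a single input axis, apply msign to the
    # resulting matrix, then restore the original kernel shape.
    w, h, d_in, d_out = momentum_kernel.shape
    flat = momentum_kernel.reshape(w * h * d_in, d_out)
    return msign_fn(flat).reshape(w, h, d_in, d_out)

# Example: update direction for a 3x3 kernel mapping 16 -> 32 channels.
m = np.random.randn(3, 3, 16, 32)
direction = conv2d_muon_direction(m, msign_newton_schulz)
```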

Similarly, the $\gamma$ parameter of RMSNorm can be viewed as multiplication by a diagonal matrix. By treating its momentum as a diagonal matrix, $\text{msign}$ can also be calculated, which is equivalent to SignSGDM. The Embedding layer can be viewed as multiple $(1, d)$ matrices for $\text{msign}$ calculation, resulting in Normalized SGDM. If you want to go further, such as with Multi-Head Attention, you might consider whether the projection matrix for each Head can be individually isolated to calculate $\text{msign}$...
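Both special cases reduce to simple element-wise or row-wise operations, as this small NumPy check illustrates (just to make the claim concrete):

```python
import numpy as np

# RMSNorm gamma: msign of diag(m) is diag(sign(m)), i.e. SignSGD with momentum.
gamma_momentum = np.array([0.3, -1.2, 2.5])
print(np.sign(gamma_momentum))                       # [ 1. -1.  1.]

# Embedding row: msign of a (1, d) matrix is the unit-normalized row,
# i.e. normalized SGD with momentum.
row_momentum = np.array([[3.0, 4.0]])
print(row_momentum / np.linalg.norm(row_momentum))   # [[0.6 0.8]]
```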

Never stop tinkering!

Expected Results

Finally, if the user follows the instructions above, sets everything up correctly, and starts running, then they can begin praying for the arrival of the goddess of luck.

What kind of result should we expect? If there are no anomalies such as gradient explosion, Muon will usually be slightly better than Adam in most cases. Of course, it is not ruled out that Muon may be slightly worse in some situations, but in any case, the difference between them will not be very large. If one is significantly better than the other, it might be necessary to reconsider if there was a problem with the settings on either side.

However, this is not absolute. For example, under certain extreme settings, it is possible for Muon to be much better than Adam, with Adam failing to converge no matter how it is tuned. Conversely, the opposite is rarely encountered. In summary, I wish you good luck. If you encounter any interesting phenomena, you are welcome to discuss and analyze them together.


Source Link: https://kexue.fm/archives/11416
