64 changes: 49 additions & 15 deletions keras/src/optimizers/muon.py
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
The Muon optimizer can use both the Muon update step or the
AdamW update step based on the following:

- For any variable that isn't 2D, 3D or 4D, the AdamW step
- For any variable that isn't 2D, the AdamW step
will be used. This is not configurable.
- If the argument `exclude_embeddings` (defaults to `True`) is set
to `True`, the AdamW step will be used.
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
that takes no arguments and returns the actual value to use.
The exponential decay rate for the 1st moment estimates. Defaults to
`0.9`.
adam_beta_2: A float value or a constant float tensor, ora callable
adam_beta_2: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use.
The exponential decay rate for the 2nd moment estimates. Defaults to
`0.999`.
adam_weight_decay: Float. If set, weight decay is applied when using
the Adam optimizer.
epsilon: A small constant for numerical stability. This is
"epsilon hat" in the Kingma and Ba paper
(in the formula just before Section 2.1),
@@ -67,20 +69,25 @@ class Muon(optimizer.Optimizer):
It is recommended to use the default value
adam_lr_ratio: Float, the ratio of the learning rate when
using Adam to the main learning rate.
it is recommended to set it to 0.1
It is recommended to set it to 1
momentum: Float, momentum used by internal SGD.
ns_steps: Integer, number of Newton-Schulz iterations to run.
nesterov: Boolean, whether to use Nesterov-style momentum
{{base_optimizer_keyword_args}}
rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
that can enhance the stability of Muon, allowing it to use the
same learning rate and weight decay as Adam. Defaults to `0.2`.
Set to `None` to disable this feature.
"""

def __init__(
self,
learning_rate=0.001,
adam_beta_1=0.9,
adam_beta_2=0.999,
adam_weight_decay=0.004,
epsilon=1e-7,
weight_decay=0.1,
weight_decay=0.004,
clipnorm=None,
clipvalue=None,
global_clipnorm=None,
@@ -95,10 +102,11 @@ def __init__(
muon_a=3.4445,
muon_b=-4.7750,
muon_c=2.0315,
adam_lr_ratio=0.1,
adam_lr_ratio=1,
momentum=0.95,
ns_steps=6,
ns_steps=5,
nesterov=True,
rms_rate=0.2,
**kwargs,
):
super().__init__(
@@ -127,12 +135,13 @@ def __init__(
self.nesterov = nesterov
self.exclude_embeddings = exclude_embeddings
self.exclude_layers = exclude_layers or []
self.adam_weight_decay = adam_weight_decay
self.rms_rate = rms_rate

def _should_use_adamw(self, variable):
# To use it with 4D convolutional filters,
# it works well to just flatten their last 3 dimensions.
# any {0,1}-D parameters should all be optimized by adam
if not 1 < len(variable.shape) < 4:
if len(variable.shape) != 2:
return True
Comment on lines +144 to 145
Collaborator

I'm not following this change. In the Moonlight implementation the criteria for using Muon is that the ndim >= 2: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296

So for AdamW, the criteria would be ndim < 2.

Contributor Author

> I'm not following this change. In the Moonlight implementation the criteria for using Muon is that the `ndim >= 2`: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296
>
> So for AdamW, the criteria would be `ndim < 2`.

The optimization target of Muon is matrices. In the 3D case, reshaping into matrices is necessary for effective optimization. However, this involves too many assumptions, and introducing it would only unnecessarily increase complexity. In fact, Muon never considered the case of CNNs. It was designed with only 1D-Transformer scenarios in mind.

Contributor Author

> I'm not following this change. In the Moonlight implementation the criteria for using Muon is that the `ndim >= 2`: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L296
>
> So for AdamW, the criteria would be `ndim < 2`.

In the original implementation of MoonLight, they could ensure that the optimization target is a Transformer model based on PyTorch. However, in the Keras implementation, we cannot guarantee this. For example, in a typical case with the PyTorch backend, if we mix keras.layers.Dense and torch.nn.Linear, then the optimization targets would simultaneously include variables of shape [d_out, d_in] and [d_in, d_out].

Similarly, if the optimization target is a 3D CNN model, the parameter meanings for the CNN model differ between the "channels_last" and "channels_first" formats. We lack reasonable assumptions to perform reshaping in such cases.

The Muon optimizer in Keras should be a general-purpose optimizer, and a general-purpose optimizer should not rely on too many assumptions. Therefore, we can only use the most conservative approach: we do not optimize anything other than matrices.

This is also the reason why we do not use the Keller Jordan Version. The Keller Jordan Version assumes that the optimized matrix must be either [d_out, d_in] or [d_in, d_out], while MoonLight does not require such assumptions.
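
As an illustrative follow-up to this discussion, here is how variables of different ranks get routed (a sketch, not part of the diff: it assumes `Muon` is importable as `keras.optimizers.Muon` and calls the private `_should_use_adamw` helper shown above):

```python
import keras
from keras import layers

optimizer = keras.optimizers.Muon()

kernel = keras.Variable([[1.0, 2.0], [3.0, 4.0]])  # 2D matrix
bias = keras.Variable([0.0, 0.0])                  # 1D vector
embedding = layers.Embedding(2, 2)
embedding.build()

print(optimizer._should_use_adamw(kernel))                # False -> Muon step
print(optimizer._should_use_adamw(bias))                  # True  -> AdamW step
print(optimizer._should_use_adamw(embedding.weights[0]))  # True  -> AdamW (excluded embedding)
```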

if self.exclude_embeddings and "embedding" in variable.path.lower():
return True
@@ -185,18 +194,13 @@ def update_step(self, gradient, variable, learning_rate):
def _muon_update_step(self, gradient, variable, lr):
m = self.adam_momentums[variable.path]
self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
shape = variable.shape
if self.nesterov:
g = ops.add(gradient, self.momentum * m)
else:
g = m
update = self.zeropower_via_newtonschulz5(g, self.ns_steps)

self.assign_sub(
variable,
lr
* self.zeropower_via_newtonschulz5(g, self.ns_steps)
* max(1, shape[0] / shape[1]) ** 0.5,
)
self.assign_sub(variable, self.lr_adjust(lr * update))

def _adamw_update_step(self, gradient, variable, learning_rate):
"""Update step given gradient and the associated model variable."""
@@ -239,6 +243,20 @@ def transpose_last_axis(self, X):
X = ops.transpose(X, temp_order)
return X

def lr_adjust(self, x):
"""Adjusts learning rate based on the Moonlight implementation.
This method enhances the stability of Muon, allowing it to use the same
learning rate and weight decay as Adam. For details, see
https://arxiv.org/abs/2502.16982.
For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
where `n` and `m` are the dimensions of the matrix.
"""
if self.rms_rate is None:
return x
# moonlight version
# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
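
As a quick sanity check on the scaling above (illustrative arithmetic only; the numbers match `test_lr_adjust_2d` in the test diff further down):

```python
import math

# With rms_rate=0.2 and a (4, 2) update, lr_adjust scales the update by
# sqrt(max(4, 2)) * 0.2 = 2 * 0.2 = 0.4.
rms_rate = 0.2
n, m = 4, 2
scale = math.sqrt(max(n, m)) * rms_rate
assert abs(scale - 0.4) < 1e-9
```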

def zeropower_via_newtonschulz5(self, x, steps: int):
"""We apply the Newton-Schulz iteration to compute matrix G.

@@ -268,6 +286,20 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
x = self.transpose_last_axis(x)
return x
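
The body of `zeropower_via_newtonschulz5` is collapsed in this diff. For readers unfamiliar with the method, here is a hedged, self-contained sketch of the standard quintic Newton-Schulz iteration that Muon-style optimizers use; the coefficient defaults match `muon_a`, `muon_b`, `muon_c` above, but the helper name and exact normalization are assumptions rather than the file's actual implementation:

```python
from keras import ops


def newton_schulz_sketch(g, steps=5, a=3.4445, b=-4.7750, c=2.0315, eps=1e-7):
    """Approximately orthogonalize a 2D update `g` (illustrative sketch)."""
    x = ops.convert_to_tensor(g, dtype="float32")
    x = x / (ops.norm(x) + eps)  # Frobenius normalization keeps the iteration stable
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # iterate on the "wide" orientation
        x = ops.transpose(x)
    for _ in range(steps):
        xxt = ops.matmul(x, ops.transpose(x))
        poly = b * xxt + c * ops.matmul(xxt, xxt)
        x = a * x + ops.matmul(poly, x)
    if transposed:
        x = ops.transpose(x)
    return x
```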

def _apply_weight_decay(self, variables):
for variable in variables:
if not self._use_weight_decay(variable):
continue
if self._should_use_adamw(variable):
weight_decay_value = self.adam_weight_decay
else:
weight_decay_value = self.weight_decay
if weight_decay_value is None:
continue
wd = ops.cast(weight_decay_value, variable.dtype)
lr = ops.cast(self.learning_rate, variable.dtype)
variable.assign(variable - variable * wd * lr)
Collaborator

Use self.assign(variable, variable - variable * wd * lr)

Contributor Author

> Use self.assign(variable, variable - variable * wd * lr)

`variable.assign(variable - variable * wd * lr)`

Here, I maintain consistency with the existing weight decay implementation.


def get_config(self):
config = super().get_config()
config.update(
@@ -284,6 +316,8 @@ def get_config(self):
"ns_steps": self.ns_steps,
"nesterov": self.nesterov,
"exclude_embeddings": self.exclude_embeddings,
"adam_weight_decay": self.adam_weight_decay,
"rms_rate": self.rms_rate,
}
)
return config
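
For a concrete sense of the decoupled decay applied by `_apply_weight_decay` above, here is the arithmetic as an illustrative sketch (the numbers mirror the new weight-decay tests in `muon_test.py` below):

```python
# Decoupled weight decay shrinks the variable independently of the gradient:
# w <- w - w * wd * lr, i.e. w <- w * (1 - lr * wd).
lr, wd = 1.0, 0.01
w = 2.0
w = w - w * wd * lr
assert abs(w - 1.98) < 1e-9
```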
40 changes: 36 additions & 4 deletions keras/src/optimizers/muon_test.py
@@ -38,11 +38,11 @@ def test_should_use_adamw(self):
True,
optimizer._should_use_adamw(vars),
)
embeding = Embedding(2, 2)
embeding.build()
embedding = Embedding(2, 2)
embedding.build()
self.assertAllClose(
True,
optimizer._should_use_adamw(embeding.weights[0]),
optimizer._should_use_adamw(embedding.weights[0]),
)
vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
optimizer = Muon()
@@ -67,7 +67,10 @@ def test_muon_single_step(self):
optimizer.build([vars])
optimizer._muon_update_step(grads, vars, 0.5)
self.assertAllClose(
vars, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2
vars,
[[0.988775, 1.887053], [2.873428, 3.97035]],
rtol=1e-2,
atol=1e-2,
)

def test_clip_norm(self):
@@ -81,3 +84,32 @@ def test_clip_value(self):
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [1.0, 1.0])

def test_muon_weight_decay(self):
variable = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
weight_decay = 0.01
expected_variable = variable - variable * weight_decay
optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)
optimizer._apply_weight_decay([variable])
self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)

def test_adamw_weight_decay(self):
variable = backend.Variable(2.0)
weight_decay = 0.01
expected_variable = variable - variable * weight_decay
optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)
optimizer._apply_weight_decay([variable])

self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)

def test_lr_adjust_none(self):
opt = Muon(rms_rate=None)
x = ops.ones((4, 4))
want = x
self.assertAllClose(opt.lr_adjust(x), want)

def test_lr_adjust_2d(self):
opt = Muon(rms_rate=0.2)
x = ops.ones((4, 2))
want = x * 0.2 * 2
self.assertAllClose(opt.lr_adjust(x), want)
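
For completeness, a minimal end-to-end usage sketch of the new arguments (assuming `Muon` is exported as `keras.optimizers.Muon`; the model and hyperparameters are illustrative, not part of this PR):

```python
import numpy as np
import keras
from keras import layers

model = keras.Sequential(
    [
        keras.Input(shape=(32,)),
        layers.Dense(64, activation="relu"),  # 2D kernels -> Muon step
        layers.Dense(10),                     # 1D biases  -> AdamW step
    ]
)
optimizer = keras.optimizers.Muon(
    learning_rate=1e-3,
    weight_decay=0.004,       # decay for Muon-updated (2D) variables
    adam_weight_decay=0.004,  # decay for AdamW-updated variables
    adam_lr_ratio=1,
    rms_rate=0.2,             # set to None to disable the Moonlight RMS scaling
)
model.compile(optimizer=optimizer, loss="mse")
model.fit(np.random.rand(8, 32), np.random.rand(8, 10), epochs=1, verbose=0)
```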