From 220a7f8308fb2598b2801f9a24be1fc568ec93a2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 15 Oct 2025 16:30:39 -0700 Subject: [PATCH 1/5] Add impl of Muon optimizer. Fix #2580 --- tests/test_optim.py | 8 + timm/optim/__init__.py | 1 + timm/optim/_optim_factory.py | 9 + timm/optim/_param_groups.py | 64 ++++- timm/optim/muon.py | 542 +++++++++++++++++++++++++++++++++++ timm/optim/nadamw.py | 2 +- 6 files changed, 611 insertions(+), 15 deletions(-) create mode 100644 timm/optim/muon.py diff --git a/tests/test_optim.py b/tests/test_optim.py index f9aa287620..3c84117d34 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -394,6 +394,14 @@ def test_kron(optimizer): _test_model(optimizer, dict(lr=1e-3)) +@pytest.mark.parametrize('optimizer', ['muon']) +def test_muon(optimizer): + _test_rosenbrock( + lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) + ) + _test_model(optimizer, dict(lr=1e-3)) + + @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw']) def test_adopt(optimizer): _test_rosenbrock( diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 3727df1dcc..d5869cc206 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -13,6 +13,7 @@ from .lookahead import Lookahead from .madgrad import MADGRAD from .mars import Mars +from .muon import Muon from .nadam import NAdamLegacy from .nadamw import NAdamW from .nvnovograd import NvNovoGrad diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py index 559f91e6fa..af9ec310ba 100644 --- a/timm/optim/_optim_factory.py +++ b/timm/optim/_optim_factory.py @@ -31,6 +31,7 @@ from .lookahead import Lookahead from .madgrad import MADGRAD from .mars import Mars +from .muon import Muon from .nadam import NAdamLegacy from .nadamw import NAdamW from .nvnovograd import NvNovoGrad @@ -871,6 +872,14 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None: description='Unleashing the Power of Variance Reduction for Training Large Models', has_betas=True, ), + OptimInfo( + name='muon', + opt_class=Muon, + description='MomentUm Orthogonalized by Newton-schulz with AdamW fallback for 1D params', + has_momentum=True, + has_eps=True, + has_betas=True, + ), OptimInfo( name='novograd', opt_class=NvNovoGrad, diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py index 5048736ece..37588fa2b9 100644 --- a/timm/optim/_param_groups.py +++ b/timm/optim/_param_groups.py @@ -1,3 +1,4 @@ +import fnmatch import logging from itertools import islice from typing import Collection, Optional @@ -10,27 +11,54 @@ _logger = logging.getLogger(__name__) +def _matches_pattern(name: str, patterns: Collection[str]) -> bool: + """Check if parameter name matches any pattern (supports wildcards).""" + return any(fnmatch.fnmatch(name, pattern) for pattern in patterns) + + def param_groups_weight_decay( model: nn.Module, weight_decay: float = 1e-5, no_weight_decay_list: Collection[str] = (), + simple_params_list: Collection[str] = (), ): - no_weight_decay_list = set(no_weight_decay_list) decay = [] + decay_simple = [] no_decay = [] + no_decay_simple = [] for name, param in model.named_parameters(): if not param.requires_grad: continue - if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: - no_decay.append(param) + # Determine if this is a "simple" parameter for fallback optimizer (if available) + is_simple = _matches_pattern(name, no_weight_decay_list) + + # Determine weight decay + matches_pattern = _matches_pattern(name, no_weight_decay_list) + if param.ndim <= 1 
or name.endswith(".bias") or matches_pattern: + # No weight decay + if is_simple: + no_decay_simple.append(param) + else: + no_decay.append(param) else: - decay.append(param) - - return [ - {'params': no_decay, 'weight_decay': 0.}, - {'params': decay, 'weight_decay': weight_decay}] - + # With weight decay + if is_simple: + decay_simple.append(param) + else: + decay.append(param) + + groups = [] + if decay: + groups.append({'params': decay, 'weight_decay': weight_decay}) + if decay_simple: + groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True}) + if no_decay: + groups.append({'params': no_decay, 'weight_decay': 0.}) + if no_decay_simple: + groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True}) + + return groups def _group(it, size): it = iter(it) @@ -70,9 +98,9 @@ def param_groups_layer_decay( model: nn.Module, weight_decay: float = 0.05, no_weight_decay_list: Collection[str] = (), + simple_params_list: Collection[str] = (), weight_decay_exclude_1d: bool = True, layer_decay: float = .75, - end_layer_decay: Optional[float] = None, min_scale: float = 0., no_opt_scale: Optional[float] = None, verbose: bool = False, @@ -81,7 +109,6 @@ def param_groups_layer_decay( Parameter groups for layer-wise lr decay & weight decay Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 """ - no_weight_decay_list = set(no_weight_decay_list) param_group_names = {} # NOTE for debugging param_groups = {} @@ -99,8 +126,12 @@ def param_groups_layer_decay( if not param.requires_grad: continue - # no decay: all 1D parameters and model specific ones - if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list: + # Determine if this is a "simple" parameter for fallback optimizer (if available) + is_simple = _matches_pattern(name, simple_params_list) + + # Determine weight decay + if (weight_decay_exclude_1d and param.ndim <= 1) or _matches_pattern(name, no_weight_decay_list): + # no weight decay for 1D parameters and model specific ones g_decay = "no_decay" this_decay = 0. else: @@ -114,11 +145,14 @@ def param_groups_layer_decay( param.requires_grad = False continue - group_name = "layer_%d_%s" % (layer_id, g_decay) + simple_suffix = "_simple" if is_simple else "" + group_name = "layer_%d_%s%s" % (layer_id, g_decay, simple_suffix) + if group_name not in param_groups: param_group_names[group_name] = { "lr_scale": this_scale, "weight_decay": this_decay, + "simple": is_simple, "param_names": [], } param_groups[group_name] = { @@ -126,6 +160,8 @@ def param_groups_layer_decay( "weight_decay": this_decay, "params": [], } + if is_simple: + param_groups[group_name]["simple"] = True param_group_names[group_name]["param_names"].append(name) param_groups[group_name]["params"].append(param) diff --git a/timm/optim/muon.py b/timm/optim/muon.py new file mode 100644 index 0000000000..c6c54a048c --- /dev/null +++ b/timm/optim/muon.py @@ -0,0 +1,542 @@ +""" Muon Optimizer + +Improved Muon optimizer implementation with flexible handling of high-dimensional tensors. + +Combines PyTorch-style structure with options for: +- Batched spatial processing for convolutions in addition to flatten +- Optional spatial normalization +- Selectable coefficient presets +- Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) 
and optional fallback via param groups + +Based on implementation by Keller Jordan, see +- https://github.com/KellerJordan/Muon/blob/master/muon.py +- https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py +- https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py +- https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py + +Hacked together by Ross Wightman +""" +import numbers +from typing import List, Mapping, Optional, Sequence, Tuple, Union + +import torch + +from ._types import ParamsT +from .adamw import adamw +from .nadamw import nadamw + +# Constants from Keller Jordan's Muon +MUON_EPS = 1e-7 +DEFAULT_NS_STEPS = 5 + +_COEFFICIENTS = { + "original": [ + # Keller Jordan's Muon https://kellerjordan.github.io/posts/muon/ + (3.4445, -4.7750, 2.0315), + ], + "quintic": [ + # https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients + # From https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44 + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ], + "polar_express": [ + # Polar Express https://arxiv.org/abs/2505.16932 + # From https://github.com/NoahAmsel/PolarExpress/tree/main with safety 1e-2 + (8.237312490495555, -23.157747414558198, 16.680568411445915), + (4.082441999064835, -2.893047735332586, 0.5252849256975648), + (3.9263479922546582, -2.8547468034765298, 0.5318022422894988), + (3.2982187133085143, -2.424541981026706, 0.48632008358844075), + (2.2970369434552573, -1.63662558125903, 0.4002628455953627), + (1.8763805351440397, -1.2347896577722228, 0.35891887501668385), + (1.8564423485617974, -1.2132449880935525, 0.3568003487825883), + (1.8749994008682747, -1.2499988017229169, 0.3749994008546422), + ], + "polar_express_safer": [ + # from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py + # w/ safety 2e-2 + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.0429299351667245, -2.808917465908704, 0.5000178451051299), + (3.8916678022926563, -2.7724841532176825, 0.5060648178503389), + (3.285753657755658, -2.3681294933425394, 0.46449024233003117), + (2.3005307116270983, -1.6111665557258408, 0.3833374427545273), + (1.8631210546382593, -1.2042160621002727, 0.3421879560523383), + (1.8382572152247512, -1.1779263289537742, 0.3396513038637379), + (1.8749999923301852, -1.2499999836060613, 0.374999991275876), + ], +} + + +NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]] + + +def zeropower_via_newtonschulz( + G: torch.Tensor, + steps: int, + coefficients: List[Tuple[float, float, float]], + eps: float = MUON_EPS, + safety_factor: float = 1.0, + dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient. + + Supports batched operation over leading dimensions. 
+ + See + - https://github.com/KellerJordan/Muon/blob/master/muon.py + - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py + - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py + + Args: + G: Input gradient tensor of shape (m, n) or (batch, m, n) + steps: Number of Newton-Schulz iterations + coefficients: Coefficients (a, b, c) for the iteration + eps: Numerical stability epsilon for norm + safety_factor: Multiplicative safety factor for norm (1.01 is common safety value) + dtype: Computation dtype + + Returns: + Orthogonalized tensor of same shape as G + """ + assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D. Flatten batch dims first." + num_cs = len(coefficients) + assert num_cs >= 1 and len(coefficients[0]) == 3 + # match coefficients with # of steps, truncate or repeat last + coeff_sequence = coefficients[:steps] if steps <= num_cs else \ + coefficients + [coefficients[-1]] * (steps - num_cs) + + X = G.to(dtype=dtype, copy=True) + + # Transpose if needed (operate on dimension with fewer elements) + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + + # Normalize spectral norm to at most 1 + X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min(eps)) + + # Batched vs unbatched fused MM + mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Pre-allocate + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + # Perform Newton-Schulz iterations + for a, b, c in coeff_sequence: + mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT + mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A + mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X + X, C = C, X # swap refs to avoid copy + + if transposed: + X = X.mT + + return X + + +def get_lr_scale( + param_shape: torch.Size, + adjust_lr_fn: str = "match_rms_adamw" +) -> float: + """Adjust learning rate based on parameter shape.""" + out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.) + + if adjust_lr_fn == "original": + # Original Muon impl (https://kellerjordan.github.io/posts/muon/) + return max(1, out_chs / in_chs) ** 0.5 + elif adjust_lr_fn == "match_rms_adamw": + # Kimi (https://arxiv.org/abs/2502.16982) + return 0.2 * max(out_chs, in_chs) ** 0.5 + elif adjust_lr_fn == "rms_to_rms": + # Scion (https://arxiv.org/abs/2502.07529, https://github.com/LIONS-EPFL/scion) + # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon) + return (out_chs / in_chs) ** 0.5 + else: + assert False, f'Invalid scaling function "{adjust_lr_fn}"' + + +def reshape_for_muon( + tensor: torch.Tensor, + mode: str = "flatten", +) -> Tuple[torch.Tensor, torch.Size]: + """Reshape high-dimensional tensor for Muon processing. 
+ + Args: + tensor: Input tensor of shape (out, in, *spatial) + mode: How to handle spatial dimensions + - "flatten": Flatten spatial into output dimension (out, in*H*W) + - "batched": Batch over spatial positions (spatial_prod, out, in) for per-position orthogonalization + + Returns: + Reshaped tensor and original shape for restoration + """ + original_shape = tensor.shape + if tensor.ndim == 2: + return tensor, original_shape + if tensor.ndim < 2: + raise ValueError(f"Tensor must have at least 2 dimensions, got {tensor.ndim}") + + out_ch, in_ch = tensor.shape[:2] + if mode == "flatten": + # Flatten: (out, in, *spatial) -> (out, in * spatial_prod) + return tensor.reshape(out_ch, -1), original_shape + elif mode == "batched": + # Batched: (out, in, *spatial) -> (spatial_prod, out, in) + # Move spatial dimension to front so zeropower_via_newtonschulz batches over it + reshaped = tensor.reshape(out_ch, in_ch, -1) # (out, in, spatial_prod) + reshaped = reshaped.permute(2, 0, 1) # (spatial_prod, out, in) + return reshaped, original_shape + else: + raise ValueError(f"Unknown mode: {mode}") + + +def muon( + params: List[torch.Tensor], + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + *, + lr: float, + weight_decay: float, + momentum: float, + nesterov: bool, + ns_steps: int, + ns_coefficients: NSCoeff, + eps: float, + safety_factor: float, + adjust_lr_fn: Optional[str], + conv_mode: str, + normalize_spatial: bool, +) -> None: + """Functional API that performs Muon algorithm computation.""" + _single_tensor_muon( + params, + grads, + momentum_bufs, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + ns_coefficients=ns_coefficients, + eps=eps, + safety_factor=safety_factor, + adjust_lr_fn=adjust_lr_fn, + conv_mode=conv_mode, + normalize_spatial=normalize_spatial, + ) + + +def _single_tensor_muon( + params: List[torch.Tensor], + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + *, + lr: float, + weight_decay: float, + momentum: float, + nesterov: bool, + ns_steps: int, + ns_coefficients: NSCoeff, + eps: float, + safety_factor: float, + adjust_lr_fn: Optional[str], + conv_mode: str, + normalize_spatial: bool, +) -> None: + """Single tensor Muon update.""" + ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS) + + for i, param in enumerate(params): + grad = grads[i] + momentum_buf = momentum_bufs[i] + + # Apply weight decay + param.mul_(1 - lr * weight_decay) + + # Update momentum buffer + momentum_buf.lerp_(grad, 1. - momentum) + update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone() + + # Reshape for processing (handle 3D+ tensors like conv weights) + if update.ndim >= 3: + update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode) + else: + update_reshaped = update + original_shape = update.shape + + # Apply Newton-Schulz orthogonalization + update_ortho = zeropower_via_newtonschulz( + update_reshaped, + ns_steps, + ns_coefficients, + eps=eps, + safety_factor=safety_factor, + #dtype=torch.bfloat16, # wire to arg? 
+ ) + + # Adjust learning rate based on parameter shape + scale = get_lr_scale(update_ortho.shape, adjust_lr_fn) + + # Apply spatial normalization and permute back if in batched mode + if conv_mode == "batched" and update_ortho.ndim >= 3: + if normalize_spatial: + scale *= update_ortho.shape[0] ** -0.5 + # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod) + update_ortho = update_ortho.permute(1, 2, 0) + + # Reshape back to original shape + update_ortho = update_ortho.reshape(original_shape) + + # Apply update + param.add_(update_ortho, alpha=-lr * scale) + + +class Muon(torch.optim.Optimizer): + """Muon - MomentUm Orthogonalized by Newton-schulz + + Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and + parameter groups with 'simple=True' set. + """ + + def __init__( + self, + params: ParamsT, + lr: float = 0.02, + weight_decay: float = 0, + momentum: float = 0.95, + nesterov: bool = True, + ns_steps: int = DEFAULT_NS_STEPS, + ns_coefficients: NSCoeff = "quintic", + eps: float = MUON_EPS, + safety_factor: float = 1.0, + adjust_lr_fn: Optional[str] = "match_rms_adamw", + conv_mode: str = "flatten", + normalize_spatial: bool = True, + adamw_lr: Optional[float] = None, + betas: Tuple[float, float] = (0.9, 0.95), + ): + """ Create Muon optimizer. + Args: + params: Iterable of parameters or dicts defining parameter groups + lr: Learning rate (default: 0.02 for Muon parameters) + weight_decay: Weight decay coefficient + momentum: Momentum factor for Muon + nesterov: Whether to use Nesterov momentum + ns_steps: Number of Newton-Schulz iterations + ns_coefficients: Coefficients for NS iteration + eps: Numerical stability epsilon + safety_factor: Multiplicative safety factor for NS norm + adjust_lr_fn: LR adjustment function - "original" or "match_rms_adamw" + conv_mode: How to handle convolutions - "flatten" or "batched" + normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode + adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified + betas: AdamW beta coefficients + + Example: + ```python + # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D + optimizer = Muon(model.parameters(), lr=0.02) + + # Manual control over parameter groups + optimizer = Muon([ + {'params': weight_matrices, 'lr': 0.02}, + {'params': biases, 'simple': True, 'lr': 3e-4}, # use AdamW if simple=True + ]) + ``` + """ + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= momentum < 1.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if conv_mode not in ["flatten", "batched"]: + raise ValueError(f"Invalid conv_mode: {conv_mode}") + + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + ns_coefficients=ns_coefficients, + eps=eps, + safety_factor=safety_factor, + adjust_lr_fn=adjust_lr_fn, + conv_mode=conv_mode, + normalize_spatial=normalize_spatial, + adamw_lr=adamw_lr if adamw_lr is not None else lr, + betas=betas, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + # Separate params into Muon and AdamW groups + 
muon_params = [] + muon_grads = [] + muon_momentum_bufs = [] + + adamw_params = [] + adamw_grads = [] + adamw_exp_avgs = [] + adamw_exp_avg_sqs = [] + adamw_state_steps = [] + + for p in group["params"]: + if p.grad is None: + continue + + if p.grad.is_sparse: + raise RuntimeError("Muon does not support sparse gradients") + + # Determine if we should use Muon or AdamW fallback + force_adamw = p.ndim < 2 or group.get("simple", False) + + state = self.state[p] + + if force_adamw: + # Collect AdamW/NAdamW params + adamw_params.append(p) + adamw_grads.append(p.grad) + + # State initialization + if len(state) == 0: + state["step"] = torch.tensor(0.) + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + adamw_exp_avgs.append(state["exp_avg"]) + adamw_exp_avg_sqs.append(state["exp_avg_sq"]) + adamw_state_steps.append(state["step"]) + else: + # Collect Muon params + muon_params.append(p) + muon_grads.append(p.grad) + + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) + muon_momentum_bufs.append(state["momentum_buffer"]) + + # Apply Muon updates + if muon_params: + muon( + muon_params, + muon_grads, + muon_momentum_bufs, + lr=group["lr"], + weight_decay=group["weight_decay"], + momentum=group["momentum"], + nesterov=group["nesterov"], + ns_steps=group["ns_steps"], + ns_coefficients=group["ns_coefficients"], + eps=group["eps"], + safety_factor=group["safety_factor"], + adjust_lr_fn=group["adjust_lr_fn"], + conv_mode=group["conv_mode"], + normalize_spatial=group["normalize_spatial"], + ) + + # Apply AdamW updates + if adamw_params: + beta1, beta2 = group["betas"] + if group["nesterov"]: + # use nadamw for fallback optimizer if nesterov is enabled + nadamw( + adamw_params, + adamw_grads, + adamw_exp_avgs, + adamw_exp_avg_sqs, + adamw_state_steps, + foreach=None, + beta1=beta1, + beta2=beta2, + lr=group["adamw_lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + caution=False, + maximize=False, + capturable=False, + max_lr=None, + ) + else: + adamw( + adamw_params, + adamw_grads, + adamw_exp_avgs, + adamw_exp_avg_sqs, + [], # max_exp_avg_sqs (not using amsgrad) + adamw_state_steps, + foreach=None, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=group["adamw_lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + caution=False, + maximize=False, + capturable=False, + max_lr=None, + ) + + return loss + + +def resolve_ns_coefficients( + value: Union[str, Sequence[float], Sequence[Sequence[float]]], + presets: Mapping[str, Sequence[Sequence[float]]] +) -> List[Tuple[float, float, float]]: + # tiny helpers (kept inline for succinctness) + is_seq = lambda x: isinstance(x, Sequence) and not isinstance(x, (str, bytes)) + is_real = lambda x: isinstance(x, numbers.Real) and not isinstance(x, bool) + + def as_coeff(x: Sequence[float]) -> Tuple[float, float, float]: + if not is_seq(x) or len(x) != 3 or not all(is_real(v) for v in x): + raise ValueError(f"Coefficient must be length-3 of real numbers, got: {x!r}") + a, b, c = x # type: ignore[misc] + return float(a), float(b), float(c) + + if isinstance(value, str): + if value not in presets: + valid = ", ".join(sorted(presets.keys())) + raise ValueError(f"Unknown coefficients preset '{value}'. 
Valid options: {valid}") + seq = presets[value] + if not is_seq(seq) or len(seq) == 0: + raise ValueError(f"Preset '{value}' is empty or invalid") + return [as_coeff(item) for item in seq] # validate & cast + + if not is_seq(value): + raise TypeError( + "Coefficients must be a preset name (str), a 3-sequence (a,b,c), " + "or a sequence of 3-sequences." + ) + + # Decide single triple vs list-of-triples by structure + if len(value) == 3 and all(is_real(v) for v in value): # type: ignore[index] + return [as_coeff(value)] # single triple -> wrap + + # Otherwise treat as list/tuple of triples + out = [] + for i, item in enumerate(value): # type: ignore[assignment] + if not is_seq(item): + raise TypeError(f"Item {i} is not a sequence: {item!r}") + out.append(as_coeff(item)) + if not out: + raise ValueError("Coefficient list cannot be empty") + return out \ No newline at end of file diff --git a/timm/optim/nadamw.py b/timm/optim/nadamw.py index c8027408f5..d2de7754e9 100644 --- a/timm/optim/nadamw.py +++ b/timm/optim/nadamw.py @@ -167,7 +167,7 @@ def nadamw( eps: float, caution: bool, maximize: bool, - max_lr: float, + max_lr: Optional[float], ) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. From d73ebb6c049050cd87809d68e3c296632b4a71b2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 15 Oct 2025 16:38:15 -0700 Subject: [PATCH 2/5] Fixed type in param group helpers --- timm/optim/_param_groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py index 37588fa2b9..6035675149 100644 --- a/timm/optim/_param_groups.py +++ b/timm/optim/_param_groups.py @@ -31,7 +31,7 @@ def param_groups_weight_decay( continue # Determine if this is a "simple" parameter for fallback optimizer (if available) - is_simple = _matches_pattern(name, no_weight_decay_list) + is_simple = _matches_pattern(name, simple_params_list) # Determine weight decay matches_pattern = _matches_pattern(name, no_weight_decay_list) From da901240c1f3fea19f5f417e987e94a7d0b4e91c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 15 Oct 2025 19:04:34 -0700 Subject: [PATCH 3/5] Re-order decay/no-decay groups to match old order and pass existing test. Change end lr decay test to min scale. 
--- tests/test_optim.py | 6 +++--- timm/optim/_param_groups.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 3c84117d34..e7211bbf8f 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -552,7 +552,7 @@ def test_lookahead_radam(optimizer): ) -def test_param_groups_layer_decay_with_end_decay(): +def test_param_groups_layer_decay_with_min(): model = torch.nn.Sequential( torch.nn.Linear(10, 5), torch.nn.ReLU(), @@ -563,12 +563,12 @@ def test_param_groups_layer_decay_with_end_decay(): model, weight_decay=0.05, layer_decay=0.75, - end_layer_decay=0.5, + min_scale=0.5, verbose=True ) assert len(param_groups) > 0 - # Verify layer scaling is applied with end decay + # Verify layer scaling is applied with a min scale for group in param_groups: assert 'lr_scale' in group assert group['lr_scale'] <= 1.0 diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py index 6035675149..1526d31d14 100644 --- a/timm/optim/_param_groups.py +++ b/timm/optim/_param_groups.py @@ -49,14 +49,14 @@ def param_groups_weight_decay( decay.append(param) groups = [] - if decay: - groups.append({'params': decay, 'weight_decay': weight_decay}) - if decay_simple: - groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True}) if no_decay: groups.append({'params': no_decay, 'weight_decay': 0.}) + if decay: + groups.append({'params': decay, 'weight_decay': weight_decay}) if no_decay_simple: groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True}) + if decay_simple: + groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True}) return groups From 40ee84ad733d70d1da881ecd85db07fe79847898 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 16 Oct 2025 15:19:39 -0700 Subject: [PATCH 4/5] Add more advanced heuristics for determining muon optimization suitability for params. 
Track decision in use_muon state --- timm/optim/_optim_factory.py | 8 ++ timm/optim/_param_groups.py | 10 ++ timm/optim/muon.py | 173 +++++++++++++++++++++++++++++++---- 3 files changed, 174 insertions(+), 17 deletions(-) diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py index af9ec310ba..ab99c98513 100644 --- a/timm/optim/_optim_factory.py +++ b/timm/optim/_optim_factory.py @@ -234,6 +234,7 @@ def create_optimizer( momentum: float = 0.9, foreach: Optional[bool] = None, weight_decay_exclude_1d: bool = True, + simple_no_weight_decay: bool = False, layer_decay: Optional[float] = None, layer_decay_min_scale: Optional[float] = None, layer_decay_no_opt_scale: Optional[float] = None, @@ -250,6 +251,7 @@ def create_optimizer( momentum: Momentum factor for applicable optimizers foreach: Enable/disable foreach operation weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine) + simple_no_weight_decay: If True, params in no_weight_decay list will use simple/fallback optimizer (e.g., AdamW for Muon) layer_decay: Layer-wise learning rate decay layer_scale_min_scale: Minimum layer scale factor clamp value layer_scale_no_opt_scale: Layer scale below which optimization is disabled @@ -277,6 +279,7 @@ def create_optimizer( weight_decay=weight_decay, layer_decay=layer_decay, no_weight_decay_list=no_weight_decay, + simple_no_weight_decay=simple_no_weight_decay, weight_decay_exclude_1d=weight_decay_exclude_1d, min_scale=layer_decay_min_scale, no_opt_scale=layer_decay_no_opt_scale, @@ -287,6 +290,7 @@ def create_optimizer( model_or_params, weight_decay=weight_decay, no_weight_decay_list=no_weight_decay, + simple_no_weight_decay=simple_no_weight_decay, ) weight_decay = 0. else: @@ -1154,6 +1158,7 @@ def create_optimizer_v2( momentum: float = 0.9, foreach: Optional[bool] = None, filter_bias_and_bn: bool = True, + simple_no_weight_decay: bool = False, layer_decay: Optional[float] = None, layer_decay_min_scale: float = 0.0, layer_decay_no_opt_scale: Optional[float] = None, @@ -1181,6 +1186,8 @@ def create_optimizer_v2( filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have weight decay applied. Only used when model_or_params is a model and weight_decay > 0. + simple_no_weight_decay: If True, params in model's no_weight_decay() list will use + simple/fallback optimizer for hybrid optimizers (e.g., AdamW for Muon). layer_decay: Optional layer-wise learning rate decay factor. If provided, learning rates will be scaled by layer_decay^(max_depth - layer_depth). Only used when model_or_params is a model. 
@@ -1231,6 +1238,7 @@ def create_optimizer_v2( momentum=momentum, foreach=foreach, weight_decay_exclude_1d=filter_bias_and_bn, + simple_no_weight_decay=simple_no_weight_decay, layer_decay=layer_decay, layer_decay_min_scale=layer_decay_min_scale, layer_decay_no_opt_scale=layer_decay_no_opt_scale, diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py index 1526d31d14..d615669006 100644 --- a/timm/optim/_param_groups.py +++ b/timm/optim/_param_groups.py @@ -21,7 +21,12 @@ def param_groups_weight_decay( weight_decay: float = 1e-5, no_weight_decay_list: Collection[str] = (), simple_params_list: Collection[str] = (), + simple_no_weight_decay: bool = False, ): + # Merge no_weight_decay into simple_params if requested + if simple_no_weight_decay: + simple_params_list = set(simple_params_list) | set(no_weight_decay_list) + decay = [] decay_simple = [] no_decay = [] @@ -99,6 +104,7 @@ def param_groups_layer_decay( weight_decay: float = 0.05, no_weight_decay_list: Collection[str] = (), simple_params_list: Collection[str] = (), + simple_no_weight_decay: bool = False, weight_decay_exclude_1d: bool = True, layer_decay: float = .75, min_scale: float = 0., @@ -109,6 +115,10 @@ def param_groups_layer_decay( Parameter groups for layer-wise lr decay & weight decay Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 """ + # Merge no_weight_decay into simple_params if requested + if simple_no_weight_decay: + simple_params_list = set(simple_params_list) | set(no_weight_decay_list) + param_group_names = {} # NOTE for debugging param_groups = {} diff --git a/timm/optim/muon.py b/timm/optim/muon.py index c6c54a048c..43856d60f3 100644 --- a/timm/optim/muon.py +++ b/timm/optim/muon.py @@ -16,6 +16,7 @@ Hacked together by Ross Wightman """ +import logging import numbers from typing import List, Mapping, Optional, Sequence, Tuple, Union @@ -25,6 +26,8 @@ from .adamw import adamw from .nadamw import nadamw +_logger = logging.getLogger(__name__) + # Constants from Keller Jordan's Muon MUON_EPS = 1e-7 DEFAULT_NS_STEPS = 5 @@ -95,7 +98,7 @@ def zeropower_via_newtonschulz( steps: Number of Newton-Schulz iterations coefficients: Coefficients (a, b, c) for the iteration eps: Numerical stability epsilon for norm - safety_factor: Multiplicative safety factor for norm (1.01 is common safety value) + safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants) dtype: Computation dtype Returns: @@ -161,6 +164,78 @@ def get_lr_scale( assert False, f'Invalid scaling function "{adjust_lr_fn}"' +def _is_suitable_for_muon( + param: torch.Tensor, + min_dim_size: int = 4, + max_aspect_ratio: float = 128., + return_reason: bool = False, +) -> Union[bool, Tuple[bool, str]]: + """Check if a parameter is suitable for Muon optimization. 
+ + Args: + param: Parameter tensor + min_dim_size: Minimum size for non-unit dimensions + max_aspect_ratio: Maximum allowed aspect ratio + return_reason: If True, return (bool, reason_string), else just bool (faster) + + Returns: + If return_reason=False: bool indicating suitability + If return_reason=True: Tuple of (is_suitable, reason_string) + + Examples: + (64, 128) -> True (or (True, "ok") if return_reason=True) + (96, 3, 4, 4) -> True - will be flattened to (96, 48) + (4, 2048) -> False - extreme aspect ratio + (64,) -> False - insufficient dims + (1, 196, 768) -> False - leading unit dims + + NOTE: these rules were created to balance complexity with covering common timm model cases + Please let me know if there are non-optimal cases that you run into. + """ + + s = param.shape + # Must have at least 2 non-unit dimensions + if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2: + return (False, "insufficient_dims") if return_reason else False + + # Unit dimension in first two positions indicates: + # - Position embeddings (1, seq, dim) + # - Depthwise convs (out, 1, h, w) + # - Other degenerate cases possibly not caught by first rule + if s[0] == 1 or s[1] == 1: + return (False, "leading_unit_dims") if return_reason else False + + if param.ndim >= 3: + # For 3D+ tensors, check what dimensions will be AFTER flattening + # since that's what gets passed to Newton-Schulz iteration + # Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod) + out_ch = s[0] + in_ch_with_spatial = 1 + for d in s[1:]: + in_ch_with_spatial *= d + check_dims = (out_ch, in_ch_with_spatial) + else: + # For 2D tensors, check as-is + check_dims = s + + # Both dims should be >= minimum size + min_size = min(check_dims) + if min_size < min_dim_size: + if return_reason: + return False, f"min_dim_too_small:{min_size}" + return False + + # Aspect ratio shouldn't be too extreme + max_size = max(check_dims) + aspect_ratio = max_size / min_size + if aspect_ratio > max_aspect_ratio: + if return_reason: + return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}" + return False + + return (True, "ok") if return_reason else True + + def reshape_for_muon( tensor: torch.Tensor, mode: str = "flatten", @@ -320,6 +395,7 @@ def __init__( normalize_spatial: bool = True, adamw_lr: Optional[float] = None, betas: Tuple[float, float] = (0.9, 0.95), + verbose: bool = False, ): """ Create Muon optimizer. 
Args: @@ -337,6 +413,7 @@ def __init__( normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified betas: AdamW beta coefficients + verbose: Log parameter routing decisions (Muon vs AdamW) Example: ```python @@ -375,6 +452,7 @@ def __init__( normalize_spatial=normalize_spatial, adamw_lr=adamw_lr if adamw_lr is not None else lr, betas=betas, + verbose=verbose, ) super().__init__(params, defaults) @@ -386,6 +464,13 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + verbose = self.defaults.get("verbose", False) + + # Tracking for logging (populated on first encounter of each param) + muon_count = 0 + adamw_count = 0 + routing_reasons = {} if verbose else None + for group in self.param_groups: # Separate params into Muon and AdamW groups muon_params = [] @@ -405,18 +490,51 @@ def step(self, closure=None): if p.grad.is_sparse: raise RuntimeError("Muon does not support sparse gradients") - # Determine if we should use Muon or AdamW fallback - force_adamw = p.ndim < 2 or group.get("simple", False) - state = self.state[p] - if force_adamw: + # Determine routing on first encounter (cache in state) + if "use_muon" not in state: + # Check explicit simple flag first + reason = None + if group.get("simple", False): + state["use_muon"] = False + if verbose: + reason = "simple_flag" + else: + # Check shape suitability + if verbose: + suitable, reason = _is_suitable_for_muon(p, return_reason=True) + else: + suitable = _is_suitable_for_muon(p, return_reason=False) + state["use_muon"] = suitable + + # Track routing decision for logging + if routing_reasons is not None and reason is not None: + shape_str = "x".join(str(s) for s in p.shape) + if shape_str not in routing_reasons: + routing_reasons[shape_str] = [] + routing_reasons[shape_str].append(reason) + + # Use cached routing decision + use_muon = state["use_muon"] + if use_muon: + # Collect Muon params + muon_params.append(p) + muon_grads.append(p.grad) + muon_count += 1 + + # State initialization for Muon + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) + muon_momentum_bufs.append(state["momentum_buffer"]) + else: # Collect AdamW/NAdamW params adamw_params.append(p) adamw_grads.append(p.grad) + adamw_count += 1 - # State initialization - if len(state) == 0: + # State initialization for AdamW + if "step" not in state: state["step"] = torch.tensor(0.) 
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) @@ -424,14 +542,6 @@ def step(self, closure=None): adamw_exp_avgs.append(state["exp_avg"]) adamw_exp_avg_sqs.append(state["exp_avg_sq"]) adamw_state_steps.append(state["step"]) - else: - # Collect Muon params - muon_params.append(p) - muon_grads.append(p.grad) - - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) - muon_momentum_bufs.append(state["momentum_buffer"]) # Apply Muon updates if muon_params: @@ -495,12 +605,41 @@ def step(self, closure=None): max_lr=None, ) + # Log routing summary when we have new routing decisions + if routing_reasons and len(routing_reasons) > 0: + # Concise summary + _logger.info(f"Muon parameter routing: {muon_count} Muon, {adamw_count} AdamW") + + # Group by reason for detailed breakdown + reason_groups = {} + for shape_str, reasons in sorted(routing_reasons.items()): + for reason in reasons: + if reason not in reason_groups: + reason_groups[reason] = [] + reason_groups[reason].append(shape_str) + + # Log summary counts per reason + reason_summary = [] + for reason, shapes in sorted(reason_groups.items()): + reason_summary.append(f"{reason}={len(shapes)}") + _logger.info(f" Breakdown: {', '.join(reason_summary)}") + + # Detailed breakdown at INFO level + if _logger.isEnabledFor(logging.INFO): + for reason, shapes in sorted(reason_groups.items()): + optimizer_name = "Muon" if reason == "ok" else "AdamW" + _logger.info(f" {reason} -> {optimizer_name}:") + for shape in shapes[:10]: + _logger.info(f" {shape}") + if len(shapes) > 10: + _logger.info(f" ... and {len(shapes) - 10} more") + return loss def resolve_ns_coefficients( - value: Union[str, Sequence[float], Sequence[Sequence[float]]], - presets: Mapping[str, Sequence[Sequence[float]]] + value: Union[str, Sequence[float], Sequence[Sequence[float]]], + presets: Mapping[str, Sequence[Sequence[float]]] ) -> List[Tuple[float, float, float]]: # tiny helpers (kept inline for succinctness) is_seq = lambda x: isinstance(x, Sequence) and not isinstance(x, (str, bytes)) From 74d384e686502bdf90e598da3f7f40b5b1cfb15f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 16 Oct 2025 15:39:34 -0700 Subject: [PATCH 5/5] Remove nesterov=True default for Muon, make a 'nmuon' name that maps to Muon with nesterov=True binding --- tests/test_optim.py | 2 +- timm/optim/_optim_factory.py | 9 +++++++++ timm/optim/muon.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index e7211bbf8f..995e4fd8fb 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -394,7 +394,7 @@ def test_kron(optimizer): _test_model(optimizer, dict(lr=1e-3)) -@pytest.mark.parametrize('optimizer', ['muon']) +@pytest.mark.parametrize('optimizer', ['muon', 'nmuon']) def test_muon(optimizer): _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py index ab99c98513..7de3eda38b 100644 --- a/timm/optim/_optim_factory.py +++ b/timm/optim/_optim_factory.py @@ -884,6 +884,15 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None: has_eps=True, has_betas=True, ), + OptimInfo( + name='nmuon', + opt_class=Muon, + description='MomentUm Orthogonalized by Newton-schulz with Nesterov and NAdamW fallback for 1D params', + has_momentum=True, + has_eps=True, + has_betas=True, + 
defaults={'nesterov': True} + ), OptimInfo( name='novograd', opt_class=NvNovoGrad, diff --git a/timm/optim/muon.py b/timm/optim/muon.py index 43856d60f3..b2f13aa404 100644 --- a/timm/optim/muon.py +++ b/timm/optim/muon.py @@ -385,7 +385,7 @@ def __init__( lr: float = 0.02, weight_decay: float = 0, momentum: float = 0.95, - nesterov: bool = True, + nesterov: bool = False, ns_steps: int = DEFAULT_NS_STEPS, ns_coefficients: NSCoeff = "quintic", eps: float = MUON_EPS,
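
Usage sketch for the series as a whole (illustrative only, not part of the diffs above): it assumes the factory wiring from patches 1 and 4 and the pre-existing `timm.create_model` / `create_optimizer_v2` entry points; the `opt='muon'` / `opt='nmuon'` names, `simple_no_weight_decay`, `conv_mode`, `ns_coefficients`, `adamw_lr`, and `verbose` all come from the hunks above, while the learning-rate and weight-decay values are placeholders rather than recommendations.

```python
# Illustrative sketch of using the Muon optimizer added in this series.
# Hyperparameter values are placeholders; see the diffs above for defaults.
import timm
from timm.optim import Muon, create_optimizer_v2

model = timm.create_model('resnet50', pretrained=False)

# Via the factory: 'muon' uses plain momentum with an AdamW fallback for
# 1D/simple params, 'nmuon' binds nesterov=True and falls back to NAdamW.
opt = create_optimizer_v2(
    model,
    opt='muon',
    lr=0.02,
    weight_decay=0.05,
    simple_no_weight_decay=True,  # route model.no_weight_decay() params to the fallback
)

# Direct construction: suitable 2D+ weights are orthogonalized with Muon,
# while groups marked simple=True (and shapes failing the suitability
# heuristics from patch 4) are handled by the AdamW/NAdamW fallback.
weights = [p for p in model.parameters() if p.ndim >= 2]
others = [p for p in model.parameters() if p.ndim < 2]
opt = Muon(
    [
        {'params': weights},
        {'params': others, 'simple': True},  # force fallback optimizer
    ],
    lr=0.02,
    adamw_lr=3e-4,                     # separate lr for the fallback params
    nesterov=True,                     # NAdamW fallback when nesterov is on
    conv_mode='batched',               # per-spatial-position orthogonalization for conv kernels
    ns_coefficients='polar_express',   # coefficient preset for Newton-Schulz iterations
    verbose=True,                      # log Muon vs AdamW routing decisions
)
```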