From 59f75ca3ef571e89ecce2d34ec2ef49f4f41a9ea Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 20 Oct 2025 14:29:22 -0700
Subject: [PATCH] Rename 'simple' flag for Muon to 'fallback', add support for
 inverted 'use_muon' to be compat with other Muon impl. Add fallback_list arg
 to full optim factory call chain.

---
 timm/optim/_optim_factory.py | 26 ++++++++++------
 timm/optim/_param_groups.py  | 58 ++++++++++++++++++------------------
 timm/optim/muon.py           | 16 ++++++----
 3 files changed, 57 insertions(+), 43 deletions(-)

diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 7de3eda38b..4d99406c75 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -5,7 +5,7 @@
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type, Union
 from fnmatch import fnmatch
 import importlib

@@ -234,7 +234,8 @@ def create_optimizer(
         momentum: float = 0.9,
         foreach: Optional[bool] = None,
         weight_decay_exclude_1d: bool = True,
-        simple_no_weight_decay: bool = False,
+        fallback_list: Collection[str] = (),
+        fallback_no_weight_decay: bool = False,
         layer_decay: Optional[float] = None,
         layer_decay_min_scale: Optional[float] = None,
         layer_decay_no_opt_scale: Optional[float] = None,
@@ -251,7 +252,8 @@
         momentum: Momentum factor for applicable optimizers
         foreach: Enable/disable foreach operation
         weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
-        simple_no_weight_decay: If True, params in no_weight_decay list will use simple/fallback optimizer (e.g., AdamW for Muon)
+        fallback_list: Collection of parameter name patterns to use fallback optimizer for hybrid optimizers
+        fallback_no_weight_decay: If True, params in no_weight_decay list will use fallback optimizer (e.g., AdamW for Muon)
         layer_decay: Layer-wise learning rate decay
         layer_scale_min_scale: Minimum layer scale factor clamp value
         layer_scale_no_opt_scale: Layer scale below which optimization is disabled
@@ -279,7 +281,8 @@
                 weight_decay=weight_decay,
                 layer_decay=layer_decay,
                 no_weight_decay_list=no_weight_decay,
-                simple_no_weight_decay=simple_no_weight_decay,
+                fallback_list=fallback_list,
+                fallback_no_weight_decay=fallback_no_weight_decay,
                 weight_decay_exclude_1d=weight_decay_exclude_1d,
                 min_scale=layer_decay_min_scale,
                 no_opt_scale=layer_decay_no_opt_scale,
@@ -290,7 +293,8 @@
                 model_or_params,
                 weight_decay=weight_decay,
                 no_weight_decay_list=no_weight_decay,
-                simple_no_weight_decay=simple_no_weight_decay,
+                fallback_list=fallback_list,
+                fallback_no_weight_decay=fallback_no_weight_decay,
             )
             weight_decay = 0.
         else:
@@ -1167,7 +1171,8 @@
         momentum: float = 0.9,
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
-        simple_no_weight_decay: bool = False,
+        fallback_list: Collection[str] = (),
+        fallback_no_weight_decay: bool = False,
         layer_decay: Optional[float] = None,
         layer_decay_min_scale: float = 0.0,
         layer_decay_no_opt_scale: Optional[float] = None,
@@ -1195,8 +1200,10 @@
         filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
             weight decay applied. Only used when model_or_params is a model and
             weight_decay > 0.
-        simple_no_weight_decay: If True, params in model's no_weight_decay() list will use
-            simple/fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
+        fallback_list: Collection of parameter name patterns to use fallback optimizer for
+            hybrid optimizers (e.g., AdamW for Muon). Supports wildcard matching.
+        fallback_no_weight_decay: If True, params in model's no_weight_decay() list will use
+            fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
         layer_decay: Optional layer-wise learning rate decay factor. If provided, learning
             rates will be scaled by layer_decay^(max_depth - layer_depth). Only used when
             model_or_params is a model.
@@ -1247,7 +1254,8 @@
         momentum=momentum,
         foreach=foreach,
         weight_decay_exclude_1d=filter_bias_and_bn,
-        simple_no_weight_decay=simple_no_weight_decay,
+        fallback_list=fallback_list,
+        fallback_no_weight_decay=fallback_no_weight_decay,
         layer_decay=layer_decay,
         layer_decay_min_scale=layer_decay_min_scale,
         layer_decay_no_opt_scale=layer_decay_no_opt_scale,
diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py
index d615669006..80b4bf6a96 100644
--- a/timm/optim/_param_groups.py
+++ b/timm/optim/_param_groups.py
@@ -20,36 +20,36 @@ def param_groups_weight_decay(
         model: nn.Module,
         weight_decay: float = 1e-5,
         no_weight_decay_list: Collection[str] = (),
-        simple_params_list: Collection[str] = (),
-        simple_no_weight_decay: bool = False,
+        fallback_list: Collection[str] = (),
+        fallback_no_weight_decay: bool = False,
 ):
-    # Merge no_weight_decay into simple_params if requested
-    if simple_no_weight_decay:
-        simple_params_list = set(simple_params_list) | set(no_weight_decay_list)
+    # Merge no_weight_decay into fallback_list if requested
+    if fallback_no_weight_decay:
+        fallback_list = set(fallback_list) | set(no_weight_decay_list)

     decay = []
-    decay_simple = []
+    decay_fallback = []
     no_decay = []
-    no_decay_simple = []
+    no_decay_fallback = []
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue

-        # Determine if this is a "simple" parameter for fallback optimizer (if available)
-        is_simple = _matches_pattern(name, simple_params_list)
+        # Determine if this is a "fallback" parameter for fallback optimizer (if available)
+        is_fallback = _matches_pattern(name, fallback_list)

         # Determine weight decay
         matches_pattern = _matches_pattern(name, no_weight_decay_list)
         if param.ndim <= 1 or name.endswith(".bias") or matches_pattern:
             # No weight decay
-            if is_simple:
-                no_decay_simple.append(param)
+            if is_fallback:
+                no_decay_fallback.append(param)
             else:
                 no_decay.append(param)
         else:
             # With weight decay
-            if is_simple:
-                decay_simple.append(param)
+            if is_fallback:
+                decay_fallback.append(param)
             else:
                 decay.append(param)

@@ -58,10 +58,10 @@ def param_groups_weight_decay(
         groups.append({'params': no_decay, 'weight_decay': 0.})
     if decay:
         groups.append({'params': decay, 'weight_decay': weight_decay})
-    if no_decay_simple:
-        groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True})
-    if decay_simple:
-        groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})
+    if no_decay_fallback:
+        groups.append({'params': no_decay_fallback, 'weight_decay': 0., 'use_fallback': True})
+    if decay_fallback:
+        groups.append({'params': decay_fallback, 'weight_decay': weight_decay, 'use_fallback': True})
     return groups


@@ -103,8 +103,8 @@ def param_groups_layer_decay(
         model: nn.Module,
         weight_decay: float = 0.05,
         no_weight_decay_list: Collection[str] = (),
-        simple_params_list: Collection[str] = (),
-        simple_no_weight_decay: bool = False,
+        fallback_list: Collection[str] = (),
+        fallback_no_weight_decay: bool = False,
         weight_decay_exclude_1d: bool = True,
         layer_decay: float = .75,
         min_scale: float = 0.,
@@ -115,9 +115,9 @@
     Parameter groups for layer-wise lr decay & weight decay
     Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
     """
-    # Merge no_weight_decay into simple_params if requested
-    if simple_no_weight_decay:
-        simple_params_list = set(simple_params_list) | set(no_weight_decay_list)
+    # Merge no_weight_decay into fallback_list if requested
+    if fallback_no_weight_decay:
+        fallback_list = set(fallback_list) | set(no_weight_decay_list)

     param_group_names = {}  # NOTE for debugging
     param_groups = {}
@@ -136,8 +136,8 @@
         if not param.requires_grad:
             continue

-        # Determine if this is a "simple" parameter for fallback optimizer (if available)
-        is_simple = _matches_pattern(name, simple_params_list)
+        # Determine if this is a "fallback" parameter for fallback optimizer (if available)
+        is_fallback = _matches_pattern(name, fallback_list)

         # Determine weight decay
         if (weight_decay_exclude_1d and param.ndim <= 1) or _matches_pattern(name, no_weight_decay_list):
@@ -155,14 +155,14 @@
             param.requires_grad = False
             continue

-        simple_suffix = "_simple" if is_simple else ""
-        group_name = "layer_%d_%s%s" % (layer_id, g_decay, simple_suffix)
+        fallback_suffix = "_fallback" if is_fallback else ""
+        group_name = "layer_%d_%s%s" % (layer_id, g_decay, fallback_suffix)

         if group_name not in param_groups:
             param_group_names[group_name] = {
                 "lr_scale": this_scale,
                 "weight_decay": this_decay,
-                "simple": is_simple,
+                "use_fallback": is_fallback,
                 "param_names": [],
             }
             param_groups[group_name] = {
@@ -170,8 +170,8 @@
                 "weight_decay": this_decay,
                 "params": [],
             }
-            if is_simple:
-                param_groups[group_name]["simple"] = True
+            if is_fallback:
+                param_groups[group_name]["use_fallback"] = True

         param_group_names[group_name]["param_names"].append(name)
         param_groups[group_name]["params"].append(param)
diff --git a/timm/optim/muon.py b/timm/optim/muon.py
index b2f13aa404..15e0ef1b56 100644
--- a/timm/optim/muon.py
+++ b/timm/optim/muon.py
@@ -376,7 +376,7 @@ class Muon(torch.optim.Optimizer):
     """Muon - MomentUm Orthogonalized by Newton-schulz

     Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
-    parameter groups with 'simple=True' set.
+    parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
     """

     def __init__(
@@ -423,7 +423,7 @@ def __init__(
             # Manual control over parameter groups
             optimizer = Muon([
                 {'params': weight_matrices, 'lr': 0.02},
-                {'params': biases, 'simple': True, 'lr': 3e-4},  # use AdamW if simple=True
+                {'params': biases, 'use_fallback': True, 'lr': 3e-4},  # use AdamW if use_fallback=True
             ])
             ```
         """
@@ -494,12 +494,18 @@ def step(self, closure=None):

                 # Determine routing on first encounter (cache in state)
                 if "use_muon" not in state:
-                    # Check explicit simple flag first
+                    # Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility)
                     reason = None
-                    if group.get("simple", False):
+                    if group.get("use_fallback", False):
+                        # use_fallback=True means use AdamW (use_muon=False)
                         state["use_muon"] = False
                         if verbose:
-                            reason = "simple_flag"
+                            reason = "use_fallback_flag"
+                    elif "use_muon" in group:
+                        # Explicit use_muon flag for compatibility with other Muon implementations
+                        state["use_muon"] = group["use_muon"]
+                        if verbose:
+                            reason = "use_muon_flag"
                     else:
                         # Check shape suitability
                         if verbose: