Merged
26 changes: 17 additions & 9 deletions timm/optim/_optim_factory.py
@@ -5,7 +5,7 @@
import logging
from dataclasses import dataclass
from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type, Union
from fnmatch import fnmatch
import importlib

@@ -234,7 +234,8 @@ def create_optimizer(
momentum: float = 0.9,
foreach: Optional[bool] = None,
weight_decay_exclude_1d: bool = True,
-simple_no_weight_decay: bool = False,
+fallback_list: Collection[str] = (),
+fallback_no_weight_decay: bool = False,
layer_decay: Optional[float] = None,
layer_decay_min_scale: Optional[float] = None,
layer_decay_no_opt_scale: Optional[float] = None,
@@ -251,7 +252,8 @@ def create_optimizer(
momentum: Momentum factor for applicable optimizers
foreach: Enable/disable foreach operation
weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
-simple_no_weight_decay: If True, params in no_weight_decay list will use simple/fallback optimizer (e.g., AdamW for Muon)
+fallback_list: Collection of parameter name patterns to use fallback optimizer for hybrid optimizers
+fallback_no_weight_decay: If True, params in no_weight_decay list will use fallback optimizer (e.g., AdamW for Muon)
layer_decay: Layer-wise learning rate decay
layer_decay_min_scale: Minimum layer scale factor clamp value
layer_decay_no_opt_scale: Layer scale below which optimization is disabled
@@ -279,7 +281,8 @@ def create_optimizer(
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
-simple_no_weight_decay=simple_no_weight_decay,
+fallback_list=fallback_list,
+fallback_no_weight_decay=fallback_no_weight_decay,
weight_decay_exclude_1d=weight_decay_exclude_1d,
min_scale=layer_decay_min_scale,
no_opt_scale=layer_decay_no_opt_scale,
@@ -290,7 +293,8 @@
model_or_params,
weight_decay=weight_decay,
no_weight_decay_list=no_weight_decay,
-simple_no_weight_decay=simple_no_weight_decay,
+fallback_list=fallback_list,
+fallback_no_weight_decay=fallback_no_weight_decay,
)
weight_decay = 0.
else:
@@ -1167,7 +1171,8 @@ def create_optimizer_v2(
momentum: float = 0.9,
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
-simple_no_weight_decay: bool = False,
+fallback_list: Collection[str] = (),
+fallback_no_weight_decay: bool = False,
layer_decay: Optional[float] = None,
layer_decay_min_scale: float = 0.0,
layer_decay_no_opt_scale: Optional[float] = None,
@@ -1195,8 +1200,10 @@ def create_optimizer_v2(
filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
weight decay applied. Only used when model_or_params is a model and
weight_decay > 0.
-simple_no_weight_decay: If True, params in model's no_weight_decay() list will use
-simple/fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
+fallback_list: Collection of parameter name patterns to use fallback optimizer for
+hybrid optimizers (e.g., AdamW for Muon). Supports wildcard matching.
+fallback_no_weight_decay: If True, params in model's no_weight_decay() list will use
+fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
layer_decay: Optional layer-wise learning rate decay factor. If provided,
learning rates will be scaled by layer_decay^(max_depth - layer_depth).
Only used when model_or_params is a model.
@@ -1247,7 +1254,8 @@ def create_optimizer_v2(
momentum=momentum,
foreach=foreach,
weight_decay_exclude_1d=filter_bias_and_bn,
-simple_no_weight_decay=simple_no_weight_decay,
+fallback_list=fallback_list,
+fallback_no_weight_decay=fallback_no_weight_decay,
layer_decay=layer_decay,
layer_decay_min_scale=layer_decay_min_scale,
layer_decay_no_opt_scale=layer_decay_no_opt_scale,
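For reference, a minimal usage sketch of the new factory arguments (the model name and the 'muon' opt string are assumptions for illustration, not part of this diff):

```python
import timm
from timm.optim import create_optimizer_v2

# Hybrid Muon: 2D+ weight matrices get the Muon update; params routed to the
# fallback get AdamW. Patterns are fnmatch-style, matched against parameter names.
model = timm.create_model('vit_base_patch16_224', pretrained=False)
optimizer = create_optimizer_v2(
    model,
    opt='muon',                     # assumes 'muon' is the registered name of the hybrid optimizer
    lr=2e-2,
    weight_decay=0.05,
    fallback_list=('cls_token', 'pos_embed', 'patch_embed.*'),
    fallback_no_weight_decay=True,  # also send model.no_weight_decay() params to the fallback
)
```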
58 changes: 29 additions & 29 deletions timm/optim/_param_groups.py
@@ -20,36 +20,36 @@ def param_groups_weight_decay(
model: nn.Module,
weight_decay: float = 1e-5,
no_weight_decay_list: Collection[str] = (),
-simple_params_list: Collection[str] = (),
-simple_no_weight_decay: bool = False,
+fallback_list: Collection[str] = (),
+fallback_no_weight_decay: bool = False,
):
-# Merge no_weight_decay into simple_params if requested
-if simple_no_weight_decay:
-simple_params_list = set(simple_params_list) | set(no_weight_decay_list)
+# Merge no_weight_decay into fallback_list if requested
+if fallback_no_weight_decay:
+fallback_list = set(fallback_list) | set(no_weight_decay_list)

decay = []
-decay_simple = []
+decay_fallback = []
no_decay = []
-no_decay_simple = []
+no_decay_fallback = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue

# Determine if this is a "simple" parameter for fallback optimizer (if available)
is_simple = _matches_pattern(name, simple_params_list)
# Determine if this is a "fallback" parameter for fallback optimizer (if available)
is_fallback = _matches_pattern(name, fallback_list)

# Determine weight decay
matches_pattern = _matches_pattern(name, no_weight_decay_list)
if param.ndim <= 1 or name.endswith(".bias") or matches_pattern:
# No weight decay
-if is_simple:
-no_decay_simple.append(param)
+if is_fallback:
+no_decay_fallback.append(param)
else:
no_decay.append(param)
else:
# With weight decay
-if is_simple:
-decay_simple.append(param)
+if is_fallback:
+decay_fallback.append(param)
else:
decay.append(param)

@@ -58,10 +58,10 @@ def param_groups_weight_decay(
groups.append({'params': no_decay, 'weight_decay': 0.})
if decay:
groups.append({'params': decay, 'weight_decay': weight_decay})
-if no_decay_simple:
-groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True})
-if decay_simple:
-groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})
+if no_decay_fallback:
+groups.append({'params': no_decay_fallback, 'weight_decay': 0., 'use_fallback': True})
+if decay_fallback:
+groups.append({'params': decay_fallback, 'weight_decay': weight_decay, 'use_fallback': True})

return groups
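A quick sketch of the resulting grouping (the toy module and pattern below are assumptions for illustration):

```python
import torch.nn as nn
from timm.optim._param_groups import param_groups_weight_decay

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = param_groups_weight_decay(
    model,
    weight_decay=0.05,
    fallback_list=('1.*',),  # fnmatch pattern covering both LayerNorm params
)
for g in groups:
    print(len(g['params']), g['weight_decay'], g.get('use_fallback', False))
# 1 0.0  False   <- 0.bias: 1d param, main optimizer, no decay
# 1 0.05 False   <- 0.weight: 2d param, main optimizer, decayed
# 2 0.0  True    <- 1.weight / 1.bias: matched by fallback_list, routed to the fallback
```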

@@ -103,8 +103,8 @@ def param_groups_layer_decay(
model: nn.Module,
weight_decay: float = 0.05,
no_weight_decay_list: Collection[str] = (),
-simple_params_list: Collection[str] = (),
-simple_no_weight_decay: bool = False,
+fallback_list: Collection[str] = (),
+fallback_no_weight_decay: bool = False,
weight_decay_exclude_1d: bool = True,
layer_decay: float = .75,
min_scale: float = 0.,
@@ -115,9 +115,9 @@
Parameter groups for layer-wise lr decay & weight decay
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
"""
-# Merge no_weight_decay into simple_params if requested
-if simple_no_weight_decay:
-simple_params_list = set(simple_params_list) | set(no_weight_decay_list)
+# Merge no_weight_decay into fallback_list if requested
+if fallback_no_weight_decay:
+fallback_list = set(fallback_list) | set(no_weight_decay_list)

param_group_names = {} # NOTE for debugging
param_groups = {}
@@ -136,8 +136,8 @@
if not param.requires_grad:
continue

# Determine if this is a "simple" parameter for fallback optimizer (if available)
is_simple = _matches_pattern(name, simple_params_list)
# Determine if this is a "fallback" parameter for fallback optimizer (if available)
is_fallback = _matches_pattern(name, fallback_list)

# Determine weight decay
if (weight_decay_exclude_1d and param.ndim <= 1) or _matches_pattern(name, no_weight_decay_list):
@@ -155,23 +155,23 @@ def param_groups_layer_decay(
param.requires_grad = False
continue

simple_suffix = "_simple" if is_simple else ""
group_name = "layer_%d_%s%s" % (layer_id, g_decay, simple_suffix)
fallback_suffix = "_fallback" if is_fallback else ""
group_name = "layer_%d_%s%s" % (layer_id, g_decay, fallback_suffix)

if group_name not in param_groups:
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"simple": is_simple,
"use_fallback": is_fallback,
"param_names": [],
}
param_groups[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
-if is_simple:
-param_groups[group_name]["simple"] = True
+if is_fallback:
+param_groups[group_name]["use_fallback"] = True

param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)
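A sketch of the layer-decay path with fallback routing (the ViT model name is an assumption; any timm model with layer grouping support works the same way):

```python
import timm
from timm.optim._param_groups import param_groups_layer_decay

model = timm.create_model('vit_small_patch16_224', pretrained=False)
groups = param_groups_layer_decay(
    model,
    weight_decay=0.05,
    layer_decay=0.75,
    no_weight_decay_list=model.no_weight_decay(),
    fallback_no_weight_decay=True,  # cls_token / pos_embed etc. go to the fallback optimizer
)
# Every group carries an lr_scale; groups named "layer_<i>_<decay>_fallback" also
# carry 'use_fallback': True so a hybrid optimizer such as Muon can route them.
print(sum(1 for g in groups if g.get('use_fallback')), 'fallback groups of', len(groups))
```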
16 changes: 11 additions & 5 deletions timm/optim/muon.py
@@ -376,7 +376,7 @@ class Muon(torch.optim.Optimizer):
"""Muon - MomentUm Orthogonalized by Newton-schulz
Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
-parameter groups with 'simple=True' set.
+parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
"""

def __init__(
@@ -423,7 +423,7 @@ def __init__(
# Manual control over parameter groups
optimizer = Muon([
{'params': weight_matrices, 'lr': 0.02},
-{'params': biases, 'simple': True, 'lr': 3e-4}, # use AdamW if simple=True
+{'params': biases, 'use_fallback': True, 'lr': 3e-4}, # use AdamW if use_fallback=True
])
```
"""
@@ -494,12 +494,18 @@ def step(self, closure=None):

# Determine routing on first encounter (cache in state)
if "use_muon" not in state:
-# Check explicit simple flag first
+# Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility)
reason = None
if group.get("simple", False):
if group.get("use_fallback", False):
# use_fallback=True means use AdamW (use_muon=False)
state["use_muon"] = False
if verbose:
reason = "simple_flag"
reason = "use_fallback_flag"
elif "use_muon" in group:
# Explicit use_muon flag for compatibility with other Muon implementations
state["use_muon"] = group["use_muon"]
if verbose:
reason = "use_muon_flag"
else:
# Check shape suitability
if verbose:
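Finally, a small end-to-end sketch of the routing flags that the step() logic above inspects (the toy network is an assumption; group construction mirrors the docstring example):

```python
import torch
import torch.nn as nn
from timm.optim.muon import Muon

net = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
matrices = [p for p in net.parameters() if p.ndim >= 2]
vectors = [p for p in net.parameters() if p.ndim < 2]

optimizer = Muon([
    {'params': matrices, 'lr': 0.02},                       # routed to Muon via the shape check
    {'params': vectors, 'use_fallback': True, 'lr': 3e-4},  # explicit flag: AdamW fallback
    # a group with 'use_muon': False is also honoured, for compatibility
])

loss = net(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
```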