14 changes: 11 additions & 3 deletions tests/test_optim.py
@@ -394,6 +394,14 @@ def test_kron(optimizer):
_test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['muon', 'nmuon'])
def test_muon(optimizer):
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
def test_adopt(optimizer):
_test_rosenbrock(
@@ -544,7 +552,7 @@ def test_lookahead_radam(optimizer):
)


def test_param_groups_layer_decay_with_end_decay():
def test_param_groups_layer_decay_with_min():
model = torch.nn.Sequential(
torch.nn.Linear(10, 5),
torch.nn.ReLU(),
@@ -555,12 +563,12 @@ def test_param_groups_layer_decay_with_end_decay():
model,
weight_decay=0.05,
layer_decay=0.75,
end_layer_decay=0.5,
min_scale=0.5,
verbose=True
)

assert len(param_groups) > 0
# Verify layer scaling is applied with end decay
# Verify layer scaling is applied with a min scale
for group in param_groups:
assert 'lr_scale' in group
assert group['lr_scale'] <= 1.0
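For reference, a minimal sketch of how the new 'muon' / 'nmuon' registry entries are exercised through the optimizer factory, mirroring the test above; the toy model and hyperparameter values are illustrative assumptions, not part of the PR.

# Minimal usage sketch (toy model and values are assumptions, not from the PR).
import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 2),
)

# 'muon' uses plain momentum; 'nmuon' is the same class registered with nesterov=True.
opt = create_optimizer_v2(model, 'muon', lr=1e-3, weight_decay=0.05)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
opt.zero_grad()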
1 change: 1 addition & 0 deletions timm/optim/__init__.py
@@ -13,6 +13,7 @@
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .mars import Mars
from .muon import Muon
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
26 changes: 26 additions & 0 deletions timm/optim/_optim_factory.py
@@ -31,6 +31,7 @@
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .mars import Mars
from .muon import Muon
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
@@ -233,6 +234,7 @@ def create_optimizer(
momentum: float = 0.9,
foreach: Optional[bool] = None,
weight_decay_exclude_1d: bool = True,
simple_no_weight_decay: bool = False,
layer_decay: Optional[float] = None,
layer_decay_min_scale: Optional[float] = None,
layer_decay_no_opt_scale: Optional[float] = None,
@@ -249,6 +251,7 @@
momentum: Momentum factor for applicable optimizers
foreach: Enable/disable foreach operation
weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
simple_no_weight_decay: If True, params in no_weight_decay list will use simple/fallback optimizer (e.g., AdamW for Muon)
layer_decay: Layer-wise learning rate decay
layer_decay_min_scale: Minimum layer scale factor clamp value
layer_decay_no_opt_scale: Layer scale below which optimization is disabled
@@ -276,6 +279,7 @@
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
simple_no_weight_decay=simple_no_weight_decay,
weight_decay_exclude_1d=weight_decay_exclude_1d,
min_scale=layer_decay_min_scale,
no_opt_scale=layer_decay_no_opt_scale,
@@ -286,6 +290,7 @@
model_or_params,
weight_decay=weight_decay,
no_weight_decay_list=no_weight_decay,
simple_no_weight_decay=simple_no_weight_decay,
)
weight_decay = 0.
else:
@@ -871,6 +876,23 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
description='Unleashing the Power of Variance Reduction for Training Large Models',
has_betas=True,
),
OptimInfo(
name='muon',
opt_class=Muon,
description='MomentUm Orthogonalized by Newton-schulz with AdamW fallback for 1D params',
has_momentum=True,
has_eps=True,
has_betas=True,
),
OptimInfo(
name='nmuon',
opt_class=Muon,
description='MomentUm Orthogonalized by Newton-schulz with Nesterov and NAdamW fallback for 1D params',
has_momentum=True,
has_eps=True,
has_betas=True,
defaults={'nesterov': True}
),
OptimInfo(
name='novograd',
opt_class=NvNovoGrad,
@@ -1145,6 +1167,7 @@ def create_optimizer_v2(
momentum: float = 0.9,
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
simple_no_weight_decay: bool = False,
layer_decay: Optional[float] = None,
layer_decay_min_scale: float = 0.0,
layer_decay_no_opt_scale: Optional[float] = None,
@@ -1172,6 +1195,8 @@
filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
weight decay applied. Only used when model_or_params is a model and
weight_decay > 0.
simple_no_weight_decay: If True, params in model's no_weight_decay() list will use
simple/fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
layer_decay: Optional layer-wise learning rate decay factor. If provided,
learning rates will be scaled by layer_decay^(max_depth - layer_depth).
Only used when model_or_params is a model.
@@ -1222,6 +1247,7 @@
momentum=momentum,
foreach=foreach,
weight_decay_exclude_1d=filter_bias_and_bn,
simple_no_weight_decay=simple_no_weight_decay,
layer_decay=layer_decay,
layer_decay_min_scale=layer_decay_min_scale,
layer_decay_no_opt_scale=layer_decay_no_opt_scale,
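A hedged sketch of the new simple_no_weight_decay flag wired through create_optimizer / create_optimizer_v2: with a hybrid optimizer such as Muon, parameters named by the model's no_weight_decay() can be routed to the simple/fallback path (AdamW for 'muon', NAdamW for 'nmuon'). The model name and hyperparameter values below are assumptions for illustration.

# Sketch of the new factory knobs (model choice and values are illustrative).
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('vit_tiny_patch16_224', pretrained=False)

opt = create_optimizer_v2(
    model,
    'nmuon',                      # Muon w/ Nesterov momentum and NAdamW fallback
    lr=1e-3,
    weight_decay=0.05,
    simple_no_weight_decay=True,  # route model.no_weight_decay() params to the fallback optimizer
    layer_decay=0.75,             # lr_scale = 0.75 ** (max_depth - layer_depth)
    layer_decay_min_scale=0.1,    # clamp the smallest lr_scale
)

# Groups built by the param-group helpers keep 'lr_scale' and, when flagged,
# 'simple': True so the optimizer can dispatch them to its fallback path.
for group in opt.param_groups:
    print(group.get('lr_scale', 1.0), group['weight_decay'], group.get('simple', False))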
74 changes: 60 additions & 14 deletions timm/optim/_param_groups.py
@@ -1,3 +1,4 @@
import fnmatch
import logging
from itertools import islice
from typing import Collection, Optional
@@ -10,27 +11,59 @@
_logger = logging.getLogger(__name__)


def _matches_pattern(name: str, patterns: Collection[str]) -> bool:
"""Check if parameter name matches any pattern (supports wildcards)."""
return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)


def param_groups_weight_decay(
model: nn.Module,
weight_decay: float = 1e-5,
no_weight_decay_list: Collection[str] = (),
simple_params_list: Collection[str] = (),
simple_no_weight_decay: bool = False,
):
no_weight_decay_list = set(no_weight_decay_list)
# Merge no_weight_decay into simple_params if requested
if simple_no_weight_decay:
simple_params_list = set(simple_params_list) | set(no_weight_decay_list)

decay = []
decay_simple = []
no_decay = []
no_decay_simple = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue

if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
# Determine if this is a "simple" parameter for fallback optimizer (if available)
is_simple = _matches_pattern(name, simple_params_list)

# Determine weight decay
matches_pattern = _matches_pattern(name, no_weight_decay_list)
if param.ndim <= 1 or name.endswith(".bias") or matches_pattern:
# No weight decay
if is_simple:
no_decay_simple.append(param)
else:
no_decay.append(param)
else:
decay.append(param)

return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]

# With weight decay
if is_simple:
decay_simple.append(param)
else:
decay.append(param)

groups = []
if no_decay:
groups.append({'params': no_decay, 'weight_decay': 0.})
if decay:
groups.append({'params': decay, 'weight_decay': weight_decay})
if no_decay_simple:
groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True})
if decay_simple:
groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})

return groups

def _group(it, size):
it = iter(it)
@@ -70,9 +103,10 @@ def param_groups_layer_decay(
model: nn.Module,
weight_decay: float = 0.05,
no_weight_decay_list: Collection[str] = (),
simple_params_list: Collection[str] = (),
simple_no_weight_decay: bool = False,
weight_decay_exclude_1d: bool = True,
layer_decay: float = .75,
end_layer_decay: Optional[float] = None,
min_scale: float = 0.,
no_opt_scale: Optional[float] = None,
verbose: bool = False,
@@ -81,7 +115,10 @@
Parameter groups for layer-wise lr decay & weight decay
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
"""
no_weight_decay_list = set(no_weight_decay_list)
# Merge no_weight_decay into simple_params if requested
if simple_no_weight_decay:
simple_params_list = set(simple_params_list) | set(no_weight_decay_list)

param_group_names = {} # NOTE for debugging
param_groups = {}

@@ -99,8 +136,12 @@
if not param.requires_grad:
continue

# no decay: all 1D parameters and model specific ones
if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list:
# Determine if this is a "simple" parameter for fallback optimizer (if available)
is_simple = _matches_pattern(name, simple_params_list)

# Determine weight decay
if (weight_decay_exclude_1d and param.ndim <= 1) or _matches_pattern(name, no_weight_decay_list):
# no weight decay for 1D parameters and model specific ones
g_decay = "no_decay"
this_decay = 0.
else:
@@ -114,18 +155,23 @@
param.requires_grad = False
continue

group_name = "layer_%d_%s" % (layer_id, g_decay)
simple_suffix = "_simple" if is_simple else ""
group_name = "layer_%d_%s%s" % (layer_id, g_decay, simple_suffix)

if group_name not in param_groups:
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"simple": is_simple,
"param_names": [],
}
param_groups[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
if is_simple:
param_groups[group_name]["simple"] = True

param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)
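A small sketch of the reworked param_groups_weight_decay: parameter names are now matched with fnmatch-style wildcards, and params flagged as "simple" land in their own groups carrying 'simple': True for the fallback optimizer. The toy model and pattern strings are assumptions for illustration.

# Grouping sketch (toy model and patterns are assumptions).
import torch
from timm.optim._param_groups import param_groups_weight_decay

model = torch.nn.Sequential(
    torch.nn.Embedding(100, 16),   # '0.weight'
    torch.nn.Linear(16, 16),       # '1.weight', '1.bias'
    torch.nn.LayerNorm(16),        # '2.weight', '2.bias'
)

groups = param_groups_weight_decay(
    model,
    weight_decay=0.05,
    no_weight_decay_list=['2.*'],     # wildcard: exclude the LayerNorm params from decay
    simple_params_list=['0.weight'],  # send the embedding to the fallback ('simple') group
)

for g in groups:
    print(len(g['params']), g['weight_decay'], g.get('simple', False))
# Expected groups: no_decay (biases + norm params, wd=0.), decay ('1.weight', wd=0.05),
# and decay_simple ('0.weight', wd=0.05, simple=True).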
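And a hedged sketch of param_groups_layer_decay with the renamed min_scale clamp, mirroring the updated test earlier in this diff; exact group contents depend on the model's grouping, so the assertions only restate what the test checks.

# Layer-decay sketch with the renamed min_scale argument (mirrors the updated test).
import torch
from timm.optim._param_groups import param_groups_layer_decay

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 2),
)

param_groups = param_groups_layer_decay(
    model,
    weight_decay=0.05,
    layer_decay=0.75,  # lr_scale decays per layer: 0.75 ** (max_depth - layer_depth)
    min_scale=0.5,     # replaces end_layer_decay; clamps the smallest lr_scale
    verbose=True,
)

assert len(param_groups) > 0
for group in param_groups:
    assert 'lr_scale' in group
    assert group['lr_scale'] <= 1.0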