Official reference implementation for the paper "Greedy Alignment Principle for Optimizer Selection".
PyTorch optimizers — Adam (with AdamW), SGDM — that pick β₁ (or m) online from a small candidate list using Frobenius-scaled J:
Adam: J_j = (g · update_j) · √((1 − β₁_j²) / W), W = mean_p (1 / (√v_p + ε))
SGDM: J_j = (g · buf_j) · √(1 − m_j²)
The candidate with the largest J is selected each step.
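For concreteness, here is a minimal sketch of how the Adam-side score could be evaluated for K candidate β₁ values, assuming each candidate keeps its own first- and second-moment buffers and ignoring bias correction; the function name and buffer layout are illustrative assumptions, not the repository's actual code.

```python
import torch

def adam_candidate_scores(grad, exp_avgs, exp_avg_sqs, betas1, eps=1e-8):
    """Illustrative sketch of the Adam selection score J_j defined above.

    grad        : current gradient g
    exp_avgs    : K first-moment buffers, one per candidate beta1
    exp_avg_sqs : K second-moment buffers, one per candidate
    betas1      : K candidate beta1 values
    """
    scores = []
    for m_j, v_j, b1 in zip(exp_avgs, exp_avg_sqs, betas1):
        denom = v_j.sqrt().add(eps)               # sqrt(v) + eps
        W = (1.0 / denom).mean()                  # W = mean_p 1 / (sqrt(v_p) + eps)
        update_j = m_j / denom                    # candidate Adam update direction
        J = torch.sum(grad * update_j) * ((1.0 - b1 ** 2) / W).sqrt()
        scores.append(J)
    return torch.stack(scores)

# The candidate with the largest J would be selected:
# choice = int(torch.argmax(adam_candidate_scores(g, ms, vs, [0.5, 0.9, 0.99])))
```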
```
adam.py        # Adam / AdamW
sgd.py         # SGD + momentum
__init__.py    # exports Adam, SGDM
test_optim.py  # unit tests
README.md
```
Run tests:
```
python -m unittest for_submission.test_optim -v
```

| type | Behavior |
|---|---|
| 'default' | Standard Adam(W) / SGD+momentum (matches torch.optim.{Adam, AdamW, SGD} exactly). |
| 'switch' | Global pick at each step; all parameters share one chosen index, computed from summed J. Supports J_ema. |
| 'switch_parameter_wise' | Each parameter independently picks its candidate. |
J_ema ∈ [0, 1) smooths selection across steps for the global 'switch' mode: the smoothed score is J_ema · J_prev + (1 − J_ema) · J_curr. Default 0.0 (off). Hysteresis always uses the raw J.
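As a sketch of the smoothing rule (variable names are illustrative, not the repository's internals):

```python
def smooth_scores(J_prev, J_curr, J_ema):
    """EMA smoothing of the summed per-candidate scores (illustrative sketch)."""
    return [J_ema * p + (1.0 - J_ema) * c for p, c in zip(J_prev, J_curr)]

# Example: three candidates, J_ema=0.9; all parameters then share argmax(J_smoothed).
J_smoothed = smooth_scores([1.2, -0.3, 0.8], [0.1, 0.9, 0.2], J_ema=0.9)
choice = max(range(len(J_smoothed)), key=lambda j: J_smoothed[j])
```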
Adam supports AdamW via decoupled_weight_decay=True in every mode.
```python
# Vision LoRA fine-tuning (ViT/DINO + LoRA)
Adam(p, lr=1e-3, type='switch_parameter_wise',
     betas=[(0.5, 0.999), (0.99, 0.999)],
     decoupled_weight_decay=True, weight_decay=0.01)

# LLM LoRA fine-tuning
Adam(p, lr=2e-4, type='switch_parameter_wise',
     betas=[(0.8, 0.999), (0.99, 0.999)],
     decoupled_weight_decay=True)

# Global switch with smoothing (small batch / noisy gradients)
Adam(p, lr=1e-3, type='switch', J_ema=0.9,
     betas=[(0.5, 0.999), (0.9, 0.999), (0.99, 0.999)],
     decoupled_weight_decay=True, weight_decay=0.01)

# SGD baseline replacement
SGDM(p, lr=0.1, type='switch', momentum=[0.5, 0.9, 0.99])
```

Constructor signatures and defaults:

```python
Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
     decoupled_weight_decay=False,
     type='default',               # default | switch | switch_parameter_wise
     lr_modulation=False,
     soft_reset=0.9, hysteresis_threshold=10,
     second_order_momentum_reset_mode='none',  # none | same | square
     J_ema=0.0)                    # only for type='switch'

SGDM(params, lr=1e-3, momentum=0, dampening=0, weight_decay=0, nesterov=False,
     type='default',               # default | switch | switch_parameter_wise
     lr_modulation=False,
     soft_reset=0.5, hysteresis_threshold=5,
     J_ema=0.0)                    # only for type='switch'
```

betas: tuple (β₁, β₂) for 'default', list [(β₁, β₂), ...] for switch types.
momentum: float for 'default', list [m₀, m₁, ...] for switch types.
After step(), the chosen index is in param_group['last_choice'] (int for 'switch', list[int] for 'switch_parameter_wise').
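A minimal usage loop that inspects the selection after each step; the model/loss setup is illustrative and the import path is assumed from the test command above:

```python
import torch
from for_submission import Adam   # assumed package path; adjust to your layout

model = torch.nn.Linear(16, 1)
opt = Adam(model.parameters(), lr=1e-3, type='switch',
           betas=[(0.5, 0.999), (0.9, 0.999), (0.99, 0.999)])

for step in range(3):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # int for type='switch'; would be a list[int] for 'switch_parameter_wise'
    print(step, opt.param_groups[0]['last_choice'])
```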
Switch modes hold K parallel buffer sets per parameter:
| Optimizer | default | switch* |
|---|---|---|
| Adam | 2d | 2Kd |
| SGDM | d | Kd |
Typical K = 2 (Adam) or K = 3 (SGDM). The optional J_ema adds K floats per parameter group.
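As a back-of-envelope check of the table above (illustrative arithmetic, fp32 buffers assumed):

```python
def switch_state_bytes(n_params, K, optimizer='adam', bytes_per_float=4):
    """Optimizer-state memory for switch modes: 2Kd (Adam) or Kd (SGDM) floats."""
    per_param = 2 * K if optimizer == 'adam' else K
    return n_params * per_param * bytes_per_float

# e.g. a 10M-parameter LoRA adapter with K=2 Adam candidates:
# 10e6 * (2*2) * 4 bytes = 160 MB of optimizer state (vs 80 MB for default Adam).
print(switch_state_bytes(10_000_000, K=2))      # 160000000
```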
When raw J stays negative for hysteresis_threshold consecutive steps, the first-moment buffers are scaled by soft_reset. The second moment is preserved by default (Adam: second_order_momentum_reset_mode='none'); set it to 'same' or 'square' to reset the second moment as well.
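A sketch of the hysteresis logic just described; the counter handling and names are assumptions, only the qualitative behavior follows the text:

```python
def maybe_soft_reset(J_raw, state, exp_avgs, soft_reset=0.9, hysteresis_threshold=10):
    """Scale first-moment buffers once raw J has been negative for enough steps."""
    if J_raw < 0:
        state['neg_streak'] = state.get('neg_streak', 0) + 1
    else:
        state['neg_streak'] = 0
    if state['neg_streak'] >= hysteresis_threshold:
        for m in exp_avgs:
            m.mul_(soft_reset)          # second moment left untouched ('none')
        state['neg_streak'] = 0
    return state
```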
The global switch uses one dot product per (param, candidate): (g · buf) · √(1 − m²). An alternative formulation ((buf·buf) − m·(buf·prev_buf)) needs two dot products and a prev_buf.clone() per candidate (~2× compute, +Kd transient memory). We unified on the cheaper Frobenius form and added J_ema to recover cross-step smoothing if desired.
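To make the cost comparison concrete, here is a sketch of the two SGDM-side formulations; the helper functions are illustrative, not the repository's code:

```python
import torch

def J_frobenius(g, buf, m):
    """Cheaper form used here: one dot product per (param, candidate)."""
    return torch.sum(g * buf) * (1.0 - m ** 2) ** 0.5

def J_alternative(buf, prev_buf, m):
    """Alternative: two dot products plus a prev_buf clone kept per candidate."""
    return torch.sum(buf * buf) - m * torch.sum(buf * prev_buf)
```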
PyTorch ≥ 1.10, Python ≥ 3.8.
If you use this code, please cite:
```bibtex
@article{lee2025greedy,
  title={Greedy Alignment Principle for Optimizer Selection},
  author={Lee, Jaerin and Lee, Kyoung Mu},
  journal={arXiv preprint arXiv:2512.06370},
  year={2025}
}
```