Greedy Alignment Principle for Optimizer Selection

Official reference implementation for the paper "Greedy Alignment Principle for Optimizer Selection".

PyTorch optimizers — Adam (with AdamW) and SGDM — that select β₁ (or momentum m) online from a small candidate list using a Frobenius-scaled alignment score J:

Adam:  J_j = (g · update_j) · √((1 − β₁_j²) / W),   W = mean_p (1 / (√v_p + ε))
SGDM:  J_j = (g · buf_j)    · √(1 − m_j²)

The candidate with the largest J is selected each step.
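As a minimal sketch of the SGDM selection rule above (pure Python, treating gradients and momentum buffers as flat lists; the helper names are illustrative, not the repo's internals):

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def select_sgdm_candidate(grad, bufs, momenta):
    """Return (argmax index, scores) for J_j = (g . buf_j) * sqrt(1 - m_j^2)."""
    scores = [dot(grad, buf) * math.sqrt(1.0 - m * m)
              for buf, m in zip(bufs, momenta)]
    return max(range(len(scores)), key=scores.__getitem__), scores

# Two candidate momenta; the first buffer is aligned with the gradient,
# the second is nearly orthogonal/opposed, so candidate 0 wins.
g = [1.0, -0.5]
bufs = [[1.0, -0.5], [0.2, 0.9]]
idx, J = select_sgdm_candidate(g, bufs, [0.5, 0.99])
```

The √(1 − m²) factor penalizes large-momentum candidates whose buffers would otherwise dominate the dot product purely by magnitude.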

Files

adam.py          # Adam / AdamW
sgd.py           # SGD + momentum
__init__.py      # exports Adam, SGDM
test_optim.py    # unit tests
README.md

Run tests:

python -m unittest for_submission.test_optim -v

Modes

type                       Behavior
'default'                  Standard Adam(W) / SGD+momentum (matches torch.optim.{Adam, AdamW, SGD} exactly).
'switch'                   Global pick at each step; all parameters share one chosen index, computed from summed J. Supports J_ema.
'switch_parameter_wise'    Each parameter independently picks its candidate.

J_ema ∈ [0, 1) smooths selection across steps in the global 'switch' mode: J ← J_ema · J_prev + (1 − J_ema) · J_curr. Default 0.0 (off). Hysteresis always uses raw J.
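The smoothing above is a plain exponential moving average over the per-candidate scores. A sketch (pure Python; `smooth_J` is an illustrative name, not the repo's API):

```python
def smooth_J(J_prev, J_curr, J_ema=0.9):
    """Blend previous and current candidate scores:
    J_ema * J_prev + (1 - J_ema) * J_curr, element-wise per candidate.
    With J_ema = 0.0 (the default) smoothing is off: J_curr passes through."""
    return [J_ema * p + (1.0 - J_ema) * c for p, c in zip(J_prev, J_curr)]

# A noisy step where J_curr collapses to zero still mostly keeps the old ranking.
J = smooth_J([1.0, -2.0], [0.0, 0.0], J_ema=0.9)
```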

Adam supports AdamW via decoupled_weight_decay=True in every mode.

Recommended configs

# Vision LoRA fine-tuning (ViT/DINO + LoRA)
Adam(p, lr=1e-3, type='switch_parameter_wise',
     betas=[(0.5, 0.999), (0.99, 0.999)],
     decoupled_weight_decay=True, weight_decay=0.01)

# LLM LoRA fine-tuning
Adam(p, lr=2e-4, type='switch_parameter_wise',
     betas=[(0.8, 0.999), (0.99, 0.999)],
     decoupled_weight_decay=True)

# Global switch with smoothing (small batch / noisy gradients)
Adam(p, lr=1e-3, type='switch', J_ema=0.9,
     betas=[(0.5, 0.999), (0.9, 0.999), (0.99, 0.999)],
     decoupled_weight_decay=True, weight_decay=0.01)

# SGD baseline replacement
SGDM(p, lr=0.1, type='switch', momentum=[0.5, 0.9, 0.99])

API

Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
     decoupled_weight_decay=False,
     type='default',                            # default | switch | switch_parameter_wise
     lr_modulation=False,
     soft_reset=0.9, hysteresis_threshold=10,
     second_order_momentum_reset_mode='none',   # none | same | square
     J_ema=0.0)                                 # only for type='switch'

SGDM(params, lr=1e-3, momentum=0, dampening=0, weight_decay=0, nesterov=False,
     type='default',                            # default | switch | switch_parameter_wise
     lr_modulation=False,
     soft_reset=0.5, hysteresis_threshold=5,
     J_ema=0.0)                                 # only for type='switch'

betas: tuple (β₁, β₂) for 'default'; list [(β₁, β₂), ...] for switch types.
momentum: float for 'default'; list [m₀, m₁, ...] for switch types.
After step(), the chosen index is available in param_group['last_choice'] (int for 'switch', list[int] for 'switch_parameter_wise').

Memory

Switch modes hold K parallel buffer sets per parameter:

        default   switch*
Adam    2d        2Kd
SGDM    d         Kd

Typical K = 2 (Adam) or K = 3 (SGDM). The optional J_ema adds K floats per parameter group.
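The table translates directly into per-parameter buffer counts; a back-of-the-envelope helper (illustrative, not part of the package):

```python
def switch_buffer_floats(d, K, optimizer="adam"):
    """Floats held per parameter in switch modes: Adam keeps K (m, v) pairs,
    so 2*K*d floats; SGDM keeps K momentum buffers, so K*d floats."""
    return (2 if optimizer == "adam" else 1) * K * d

# A 1024x1024 weight with K = 2 Adam candidates holds 4 * 1024 * 1024 floats.
n = switch_buffer_floats(1024 * 1024, K=2, optimizer="adam")
```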

Hysteresis

When raw J stays negative for hysteresis_threshold consecutive steps, first-moment buffers are scaled by soft_reset. The second moment is preserved by default (Adam: 'none'); set 'same' or 'square' to reset it as well.
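The hysteresis rule can be sketched as a counter over consecutive negative raw-J steps (pure Python; buffers as flat lists, `hysteresis_step` is an illustrative name):

```python
def hysteresis_step(J_raw, neg_count, buf, soft_reset=0.9, threshold=10):
    """Count consecutive steps with negative raw J; when the count reaches
    the threshold, scale the first-moment buffer by soft_reset and restart."""
    if J_raw < 0:
        neg_count += 1
        if neg_count >= threshold:
            buf = [soft_reset * x for x in buf]
            neg_count = 0
    else:
        neg_count = 0  # any non-negative J resets the streak
    return neg_count, buf

buf = [1.0, -2.0]
count = 0
for _ in range(10):  # ten consecutive negative-J steps trigger one soft reset
    count, buf = hysteresis_step(-1.0, count, buf, soft_reset=0.9, threshold=10)
```

A single positive-J step anywhere in the run would have reset the counter, which is why sustained misalignment is required before the buffers are damped.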

Computation note (switch design)

The global switch uses one dot product per (param, candidate): (g · buf) · √(1 − m²). An alternative formulation ((buf·buf) − m·(buf·prev_buf)) needs two dot products and a prev_buf.clone() per candidate (~2× compute, +Kd transient memory). We unified on the cheaper Frobenius form and added J_ema to recover cross-step smoothing if desired.

Requirements

PyTorch ≥ 1.10, Python ≥ 3.8.

Citation

If you use this code, please cite:

@article{lee2025greedy,
  title={Greedy Alignment Principle for Optimizer Selection},
  author={Lee, Jaerin and Lee, Kyoung Mu},
  journal={arXiv preprint arXiv:2512.06370},
  year={2025}
}
