Skip to content

Commit

Permalink
Merge pull request #253 from kozistr/feature/adam-mini-optimizer
Browse files Browse the repository at this point in the history
[Feature] Implement AdamMini optimizer
  • Loading branch information
kozistr committed Jul 6, 2024
2 parents 5db0994 + a970453 commit 83a2f5e
Show file tree
Hide file tree
Showing 13 changed files with 396 additions and 13 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **72 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

Expand Down Expand Up @@ -168,6 +168,7 @@ supported_optimizers = get_supported_optimizers()
| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
| Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |
| AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | <https://arxiv.org/abs/2406.16793> | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) |

## Supported LR Scheduler

Expand Down
2 changes: 2 additions & 0 deletions docs/changelogs/v3.0.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
* [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648)
* Implement `StableAdamW` optimizer. (#250, #252)
* [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013)
* Implement `AdamMini` optimizer. (#246, #253)
* [Use Fewer Learning Rates To Gain More](https://arxiv.org/abs/2406.16793)

### Refactor

Expand Down
3 changes: 2 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **72 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

Expand Down Expand Up @@ -168,6 +168,7 @@ supported_optimizers = get_supported_optimizers()
| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
| Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |
| AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | <https://arxiv.org/abs/2406.16793> | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) |

## Supported LR Scheduler

Expand Down
4 changes: 4 additions & 0 deletions docs/optimizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
:docstring:
:members:

::: pytorch_optimizer.AdamMini
:docstring:
:members:

::: pytorch_optimizer.AdaMax
:docstring:
:members:
Expand Down
15 changes: 8 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
keywords = [
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
"AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad",
"DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity",
"GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
"PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
"SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME",
"DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore",
"Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero",
"NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
"SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
"SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
Expand Down Expand Up @@ -122,6 +122,7 @@ testpaths = "tests"
[tool.coverage.run]
omit = [
"./pytorch_optimizer/optimizer/rotograd.py",
"./pytorch_optimizer/optimizer/adam_mini.py",
]

[build-system]
Expand Down
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from pytorch_optimizer.optimizer.adahessian import AdaHessian
from pytorch_optimizer.optimizer.adai import Adai
from pytorch_optimizer.optimizer.adalite import Adalite
from pytorch_optimizer.optimizer.adam_mini import AdamMini
from pytorch_optimizer.optimizer.adamax import AdaMax
from pytorch_optimizer.optimizer.adamod import AdaMod
from pytorch_optimizer.optimizer.adamp import AdamP
Expand Down Expand Up @@ -203,6 +204,7 @@
GrokFastAdamW,
Kate,
StableAdamW,
AdamMini,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

Expand Down
5 changes: 5 additions & 0 deletions pytorch_optimizer/base/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ def validate_learning_rate(learning_rate: Optional[float]) -> None:
if learning_rate is not None and learning_rate < 0.0:
raise NegativeLRError(learning_rate)

@staticmethod
def validate_mod(x: int, y: int) -> None:
if x % y != 0:
raise ValueError(f'[-] {x} must be divisible by {y}')

def validate_betas(self, betas: BETAS) -> None:
if betas[0] is not None:
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
Expand Down
Loading

0 comments on commit 83a2f5e

Please sign in to comment.