Skip to content

Commit

Permalink
feat: support torch.hub
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing-su committed Sep 1, 2022
1 parent 569bf89 commit 74caee6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 10 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# torch.hub entry-point module (see https://pytorch.org/docs/stable/hub.html).
# ``dependencies`` names the pip packages torch.hub verifies before loading.
dependencies = ["torch"]

from functools import partial

from pytorch_optimizer import get_supported_optimizers, load_optimizer

# Register every supported optimizer as a hub entry point under both its
# class name (e.g. "AdaBelief") and its lowercase alias ("adabelief").
# Loop-local names are underscore-prefixed so they do not themselves leak
# into the module namespace and show up in ``torch.hub.list`` as spurious
# callable entry points (torch.hub hides names starting with "_").
for _optimizer in get_supported_optimizers():
    _name = _optimizer.__name__
    for _alias in (_name, _name.lower()):
        globals()[_alias] = partial(load_optimizer, optimizer=_alias)
10 changes: 6 additions & 4 deletions pytorch_optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# pylint: disable=unused-import
from typing import Callable, Dict, List
from typing import Dict, List, Type

from torch.optim import Optimizer

from pytorch_optimizer.adabelief import AdaBelief
from pytorch_optimizer.adabound import AdaBound
Expand Down Expand Up @@ -54,10 +56,10 @@
SGDP,
Shampoo,
]
OPTIMIZERS: Dict[str, Callable] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
# Lookup table: lowercase class name -> optimizer class, for name-based loading.
# ``__name__`` is already a ``str`` — the previous ``str(...)`` wrapper was redundant.
OPTIMIZERS: Dict[str, Type[Optimizer]] = {optimizer.__name__.lower(): optimizer for optimizer in OPTIMIZER_LIST}


def load_optimizer(optimizer: str) -> Callable:
def load_optimizer(optimizer: str) -> Type[Optimizer]:
optimizer: str = optimizer.lower()

if optimizer not in OPTIMIZERS:
Expand All @@ -66,5 +68,5 @@ def load_optimizer(optimizer: str) -> Callable:
return OPTIMIZERS[optimizer]


def get_supported_optimizers() -> List:
def get_supported_optimizers() -> List[Type[Optimizer]]:
    """Return the optimizer classes bundled with this package."""
    supported = OPTIMIZER_LIST
    return supported

0 comments on commit 74caee6

Please sign in to comment.