Commit
* add optimizer registry
* move under core, add doc
* unexpose TORCH_OPTIMIZERS
Showing 7 changed files with 155 additions and 83 deletions.
__init__.py
@@ -0,0 +1,5 @@
from .builder import build_optimizer
from .copy_of_sgd import CopyOfSGD
from .registry import OPTIMIZERS

__all__ = ['OPTIMIZERS', 'build_optimizer', 'CopyOfSGD']
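
A minimal usage sketch of the new package; the import path is an assumption based on the "move under core" bullet (the commit itself only shows the package-local files):

# Hedged usage sketch; assumes the package resolves to
# mmdet.core.optimizer after the move described in the commit message.
import torch.nn as nn

from mmdet.core.optimizer import build_optimizer

model = nn.Conv1d(1, 1, 1)
# 'SGD' resolves because registry.py (below) auto-registers every
# optimizer class found in torch.optim.
optimizer = build_optimizer(model, dict(type='SGD', lr=0.01, momentum=0.9))
print(type(optimizer).__name__)  # SGD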
builder.py
@@ -0,0 +1,85 @@
import re

from mmdet.utils import build_from_cfg
from .registry import OPTIMIZERS


def build_optimizer(model, optimizer_cfg):
    """Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.
            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with 3 accepted fields
                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
                  `bias_lr_mult` and `bias_decay_mult` multiply the lr and
                  weight decay, respectively, for all bias parameters
                  (except those of normalization layers), and
                  `norm_decay_mult` multiplies the weight decay for all
                  weight and bias parameters of normalization layers.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> import torch
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
    """
    if hasattr(model, 'module'):
        model = model.module

    optimizer_cfg = optimizer_cfg.copy()
    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
    # if no paramwise option is specified, just use the global setting
    if paramwise_options is None:
        params = model.parameters()
    else:
        assert isinstance(paramwise_options, dict)
        # get base lr and weight decay
        base_lr = optimizer_cfg['lr']
        base_wd = optimizer_cfg.get('weight_decay', None)
        # weight_decay must be explicitly specified if a decay mult is given
        if ('bias_decay_mult' in paramwise_options
                or 'norm_decay_mult' in paramwise_options):
            assert base_wd is not None
        # get param-wise options
        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
        # set param-wise lr and weight decay
        params = []
        for name, param in model.named_parameters():
            param_group = {'params': [param]}
            if not param.requires_grad:
                # FP16 training copies gradients/weights between the master
                # weight copy and the model weights, so it is convenient to
                # keep all parameters here to align with model.parameters()
                params.append(param_group)
                continue

            # for norm layers, overwrite the weight decay of weight and bias
            # TODO: obtain the norm layer prefixes dynamically
            if re.search(r'(bn|gn)(\d+)?\.(weight|bias)', name):
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
            # for other layers, overwrite both lr and weight decay of bias
            elif name.endswith('.bias'):
                param_group['lr'] = base_lr * bias_lr_mult
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult
            # otherwise use the global settings

            params.append(param_group)

    optimizer_cfg['params'] = params

    return build_from_cfg(optimizer_cfg, OPTIMIZERS)
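
To make the paramwise_options branches concrete, a small illustrative example; TinyNet and the multiplier values are invented, and the attribute name bn1 matters because the regex above keys off bn/gn prefixes (hence the TODO):

# Illustrative only; TinyNet and the multipliers are not part of the
# commit. Uses the build_optimizer defined above.
import torch.nn as nn


class TinyNet(nn.Module):

    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn1(self.conv1(x))


optimizer = build_optimizer(
    TinyNet(),
    dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001,
         paramwise_options=dict(bias_lr_mult=2., norm_decay_mult=0.)))
# conv1.weight keeps the globals (lr 0.02, wd 0.0001), conv1.bias gets
# lr 0.04, and both bn1 parameters get weight_decay 0.0.
for group in optimizer.param_groups:
    print(group['lr'], group['weight_decay'])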
copy_of_sgd.py
@@ -0,0 +1,13 @@
from torch.optim import SGD

from .registry import OPTIMIZERS


@OPTIMIZERS.register_module
class CopyOfSGD(SGD):
    """A clone of torch.optim.SGD.

    A customized optimizer can be defined in the same way as CopyOfSGD:
    derive from a built-in optimizer in torch.optim, or implement a new
    optimizer from scratch.
    """
registry.py
@@ -0,0 +1,23 @@
import inspect

import torch

from mmdet.utils import Registry

OPTIMIZERS = Registry('optimizer')


def register_torch_optimizers():
    """Register every optimizer class found in torch.optim with OPTIMIZERS.

    Returns:
        list[str]: Names of the registered optimizer classes.
    """
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module(_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()
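
A quick sketch of what the auto-registration yields; the exact name list depends on the installed PyTorch version, and per the commit message TORCH_OPTIMIZERS is intentionally no longer re-exported from the package:

# Sketch run from this module's scope; the name list varies with the
# installed PyTorch version.
import torch.nn as nn

from mmdet.utils import build_from_cfg

print(TORCH_OPTIMIZERS)  # names of the torch.optim classes just registered

# Any registered name can then be built from a type string:
model = nn.Linear(2, 2)
adam = build_from_cfg(
    dict(type='Adam', params=model.parameters(), lr=1e-3), OPTIMIZERS)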