Skip to content

Commit

Permalink
fix optimwrapper to work with param_groups (closes #2829) (#3241)
Browse files Browse the repository at this point in the history
* fix optimwrapper to work with param_groups

* change optimwrapper to make it even easier to use

* Update fastai/optimizer.py

Co-authored-by: Jeremy Howard <github@jhoward.fastmail.fm>

* incorporate jeremy's suggestions and add #3226 doc improvements

* add #slow flag to cell in 14_callback.schedule.ipynb

* switch to #cuda

Co-authored-by: Jeremy Howard <github@jhoward.fastmail.fm>
  • Loading branch information
tmabraham and jph00 committed Mar 6, 2021
1 parent 87827a1 commit 89b05f5
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 320 deletions.
15 changes: 12 additions & 3 deletions fastai/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,14 +340,23 @@ def set_item_pg(pg, k, v):
# Cell
pytorch_hp_map = {'momentum': 'mom', 'weight_decay': 'wd', 'alpha': 'sqr_mom', 'betas__0': 'mom', 'betas__1': 'sqr_mom'}

# Cell
def _convert_params(o:list) -> list:
splitter = []
for group in o:
if isinstance(group, dict): splitter.append(group)
else: splitter.append({'params':group})
return splitter

# Cell
class OptimWrapper(_BaseOptimizer, GetAttr):
"A wrapper class for existing PyTorch optimizers"
_xtra=['zero_grad', 'step', 'state_dict', 'load_state_dict']
_default='opt'
def __init__(self, opt, hp_map=None):
self.opt = opt
def __init__(self, params, opt, hp_map=None, convert_groups=True, **kwargs):
self.opt = opt(_convert_params(params), **kwargs) if convert_groups else opt(params, **kwargs)
if hp_map is None: hp_map = pytorch_hp_map
self.fwd_map = {k: hp_map[k] if k in hp_map else k for k in detuplify_pg(opt.param_groups[0]).keys()}
self.fwd_map = {k: hp_map[k] if k in hp_map else k for k in detuplify_pg(self.opt.param_groups[0]).keys()}
self.bwd_map = {v:k for k,v in self.fwd_map.items()}
self.state = defaultdict(dict, {})
self.frozen_idx = 0
Expand Down
Loading

0 comments on commit 89b05f5

Please sign in to comment.