Commit

bugs fix
dingguanglei committed Dec 24, 2018
1 parent c74a7d0 commit 173ddb3
Showing 8 changed files with 510 additions and 62 deletions.
11 changes: 6 additions & 5 deletions jdit/model.py
@@ -6,7 +6,6 @@
from typing import Union
from collections import OrderedDict
from types import FunctionType
from typing import List, Optional


class _cached_property(object):
@@ -92,6 +91,7 @@ class Model(object):
>>> output = net(input_tensor)
"""

def __init__(self, proto_model: Module,
gpu_ids_abs: Union[list, tuple] = (),
init_method: Union[str, FunctionType, None] = "kaiming",
@@ -244,14 +244,16 @@ def check_point(self, model_name: str, epoch: int, logdir="log"):
weights = self._fix_weights(self.model.state_dict(), "remove", False) # try to remove '.module' in keys.
save(weights, model_weights_path)

def check_point_epoch(self, model_name: str, epoch: int, logdir="log"):
def is_checkpoint(self, model_name: str, epoch: int, logdir="log"):
if not self.check_point_pos:
return False
if isinstance(epoch, int):
is_check_point = epoch > 0 and (epoch % self.check_point_pos) == 0
else:
is_check_point = epoch in self.check_point_pos
if is_check_point:
self.check_point(model_name, epoch, logdir)
return is_check_point
return is_check_point
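For context, a minimal sketch of how the renamed helper might be driven from a training loop; it assumes a `Model` instance named `net` whose `check_point_pos` is set, and the loop scaffolding is illustrative rather than part of this commit:

epochs = 50
for epoch in range(1, epochs + 1):
    # ... run the training step for this epoch here ...
    if net.is_checkpoint("Model", epoch, logdir="log"):
        # is_checkpoint already called check_point() internally; the return
        # value just reports that a checkpoint was written at this epoch.
        print("checkpoint saved at epoch %d" % epoch)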

@staticmethod
def count_params(proto_model: Module):
@@ -320,7 +322,7 @@ def _fix_weights(weights: Union[dict, OrderedDict], fix_type: str = "remove", is
if is_strict:
assert not k.startswith("module."), "The key of weights dict is %s. Can not add 'module.'" % k
if not k.startswith("module."):
name = "module.".join(k) # add `module.`
name = "module."+ k # add `module.`
else:
name = k
else:
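The one-character change above is the actual fix: `"module.".join(k)` treats the key string as a sequence of characters and interleaves the separator between them, whereas the intent is simply to prefix the key. A quick illustration outside the diff:

k = "conv1.weight"
print("module.".join(k))   # 'cmodule.omodule.nmodule.v...'  -- the old, broken result
print("module." + k)       # 'module.conv1.weight'           -- the intended prefix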
@@ -368,7 +370,6 @@ def configure(self):
return config_dic



if __name__ == '__main__':
from torch.nn import Sequential
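The rest of this `__main__` block is collapsed in the diff view. A rough sketch of the kind of smoke test it implies, using a made-up network and input shape; only the constructor arguments shown above and the callable forwarding from the class docstring are assumed:

import torch
from torch.nn import Sequential, Conv2d, ReLU

proto = Sequential(Conv2d(3, 8, 3, padding=1), ReLU())
net = Model(proto, gpu_ids_abs=(), init_method="kaiming")  # wrap and initialize the prototype network
output = net(torch.randn(1, 3, 32, 32))                    # Model forwards the call to the wrapped module
print(output.shape)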

58 changes: 35 additions & 23 deletions jdit/optimizer.py
@@ -1,6 +1,7 @@
# coding=utf-8
from typing import Optional, Union
import torch.optim as optim
from inspect import signature


class Optimizer(object):
@@ -93,28 +94,44 @@ class Optimizer(object):
lr: 1
weight_decay: 1e-05
)
>>> opt.use_decay(1)
>>> opt.is_lrdecay(1)
False
>>> opt.use_decay(10)
>>> opt.is_lrdecay(10)
True
>>> opt.use_decay(20)
>>> opt.is_lrdecay(20)
True
"""

def __init__(self, params, optimizer: "[Adam,RMSprop,SGD]", lr_decay=0.92, decay_position: Union[int, list] = None,
decay_type: "['epoch','step']" = "epoch", lr_minimum=1e-5,**kwargs):
assert isinstance(decay_position,
(int, tuple, list)), "`decay_position` should be int or tuple/list, get %s instead" % type(
decay_position)
assert decay_type in ['epoch', 'step'], "You need to set `decay_type` 'step' or 'epoch'"
def __init__(self, params: "parameters of model",
optimizer: "[Adam,RMSprop,SGD...]",
lr_decay=0.92,
decay_position: Union[int, list] = None,
decay_type: "['epoch','step']" = "epoch",
lr_minimum=1e-5,
**kwargs):
if not isinstance(decay_position, (int, tuple, list)):
raise TypeError("`decay_position` should be int or tuple/list, get %s instead" % type(
decay_position))
if decay_type not in ['epoch', 'step']:
raise AttributeError("You need to set `decay_type` 'step' or 'epoch', get %s instead" % decay_type)
self.lr_decay = lr_decay
self.lr_minimum = lr_minimum
self.decay_position = decay_position
self.decay_type = decay_type
self.opt_name = optimizer
Optim = getattr(optim, optimizer)
self.opt = Optim(filter(lambda p: p.requires_grad, params), **kwargs)

try:
Optim = getattr(optim, optimizer)
self.opt = Optim(filter(lambda p: p.requires_grad, params), **kwargs)
except TypeError as e:
raise TypeError(
"%s\n`%s` parameters are:\n %s\n Got %s instead." % (e, optimizer, signature(Optim), kwargs))  # inspect the class: self.opt was never assigned when this raises
except AttributeError as e:
opts = [i for i in dir(optim) if not i.endswith("__") and i not in ['lr_scheduler', 'Optimizer']]
raise AttributeError(
"%s\n`%s` is not an optimizer in torch.optim. Availible optims are:\n%s" % (e, optimizer, opts))

for param_group in self.opt.param_groups:
self.lr = param_group["lr"]
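With the asserts replaced by explicit exceptions and the construction wrapped in try/except, a misspelled optimizer name or an unexpected keyword now produces an actionable message. A sketch of both failure modes; the parameter source and hyperparameters are illustrative:

import torch
from jdit.optimizer import Optimizer

try:
    Optimizer(torch.nn.Linear(10, 1).parameters(), "Adamm", decay_position=10, lr=1e-3)
except AttributeError as e:
    print(e)   # message ends with the list of optimizers actually available in torch.optim

try:
    Optimizer(torch.nn.Linear(10, 1).parameters(), "Adam", decay_position=10, lr=1e-3, momentum=0.9)
except TypeError as e:
    print(e)   # Adam has no `momentum`; the message echoes the expected parameters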

@@ -130,23 +147,21 @@ def __getattr__(self, name):

return getattr(self.opt, name)

def use_decay(self, position: Optional[int]) -> bool:
def is_lrdecay(self, position: Optional[int]) -> bool:
"""Judge if use learning decay on this position.
:param position: (int) A position of step or epoch.
:return: bool
"""
if not self.decay_position:
return False
assert isinstance(position, int)
if isinstance(self.decay_position, int):
is_change_lr = position > 0 and (position % self.decay_position) == 0
else:
is_change_lr = position in self.decay_position
return is_change_lr

# def update_state(self, position: int):
# if self.use_decay(position):
# self.do_lr_decay()
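These commented-out `update_state` lines are dropped by the commit; the equivalent pattern pairs `is_lrdecay` with `do_lr_decay` explicitly, roughly as in this sketch (the loop scaffolding is illustrative):

epochs = 50
for epoch in range(1, epochs + 1):
    # ... run the training step for this epoch here ...
    if opt.is_lrdecay(epoch):   # True every `decay_position` epochs, or when the epoch is listed
        opt.do_lr_decay()       # apply the configured lr_decay factor to the learning rate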

def do_lr_decay(self, reset_lr_decay: float = None, reset_lr: float = None):
"""Do learning rate decay, or reset them.
@@ -200,14 +215,11 @@ def configure(self):
print(opt.configure['lr_decay'])
opt.do_lr_decay(reset_lr=0.2)
print(opt.configure)
print(opt.use_decay(1))
print(opt.use_decay(2))
print(opt.use_decay(40))
print(opt.use_decay(10))
print(opt.is_lrdecay(1))
print(opt.is_lrdecay(2))
print(opt.is_lrdecay(40))
print(opt.is_lrdecay(10))
param = torch.nn.Linear(10, 1).parameters()
hpd = {"optimizer": "Adam", "lr_decay": 0.1, "decay_position": [1, 3, 5], "decay_type": "epoch",
"lr": 0.9, "betas": (0.9, 0.999), "weight_decay": 1e-5}
opt = Optimizer(param, **hpd)
print(opt.update_state(1), opt.opt)
print(opt.update_state(3), opt.opt)
print(opt.update_state(40), opt.lr)
